diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 4135935189ba..30c05341d0e2 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -20,6 +20,8 @@ #include #include +#include "velox/common/base/Exceptions.h" + namespace facebook::velox::encoding { // Constants defining the size in bytes of binary and encoded blocks for Base64 @@ -29,12 +31,14 @@ constexpr static int kBinaryBlockByteSize = 3; // Size of an encoded block in bytes (4 bytes = 24 bits) constexpr static int kEncodedBlockByteSize = 4; +// Character sets for Base64 and Base64 URL encoding constexpr const Base64::Charset kBase64Charset = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; + constexpr const Base64::Charset kBase64UrlCharset = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', @@ -42,6 +46,7 @@ constexpr const Base64::Charset kBase64UrlCharset = { 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'}; +// Reverse lookup tables for decoding constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, @@ -61,6 +66,7 @@ constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; + constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, @@ -81,76 +87,64 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -constexpr bool checkForwardIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true); -} // Verify that for every entry in kBase64Charset, the corresponding entry // in kBase64ReverseIndexTable is correct. static_assert( - checkForwardIndex( + Base64::checkForwardIndex( sizeof(kBase64Charset) - 1, kBase64Charset, kBase64ReverseIndexTable), "kBase64Charset has incorrect entries"); + // Verify that for every entry in kBase64UrlCharset, the corresponding entry // in kBase64UrlReverseIndexTable is correct. static_assert( - checkForwardIndex( + Base64::checkForwardIndex( sizeof(kBase64UrlCharset) - 1, kBase64UrlCharset, kBase64UrlReverseIndexTable), "kBase64UrlCharset has incorrect entries"); -// Similar to strchr(), but for null-terminated const strings. -// Another difference is that we do not consider "\0" to be present in the -// string. -// Returns true if "str" contains the character c. -constexpr bool constCharsetContains( + +// static +const bool Base64::findCharacterInCharSet( const Base64::Charset& charset, uint8_t idx, const char c) { - return idx < charset.size() && - ((charset[idx] == c) || constCharsetContains(charset, idx + 1, c)); -} -constexpr bool checkReverseIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[idx] == 255 - ? !constCharsetContains(charset, 0, static_cast(idx)) - : (charset[table[idx]] == idx)) && - (idx > 0 ? checkReverseIndex(idx - 1, charset, table) : true); + for (; idx < charset.size(); ++idx) { + if (charset[idx] == c) { + return true; + } + } + return false; } + // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. static_assert( - checkReverseIndex( + Base64::checkReverseIndex( sizeof(kBase64ReverseIndexTable) - 1, kBase64Charset, kBase64ReverseIndexTable), "kBase64ReverseIndexTable has incorrect entries."); + // Verify that for every entry in kBase64ReverseIndexTable, the corresponding // entry in kBase64Charset is correct. -// We can't run this check as the URL version has two duplicate entries so that -// the url decoder can handle url encodings and default encodings -// static_assert( -// checkReverseIndex( -// sizeof(kBase64UrlReverseIndexTable) - 1, -// kBase64UrlCharset, -// kBase64UrlReverseIndexTable), -// "kBase64UrlReverseIndexTable has incorrect entries."); +static_assert( + Base64::checkReverseIndex( + sizeof(kBase64UrlReverseIndexTable) - 1, + kBase64UrlCharset, + kBase64UrlReverseIndexTable), + "kBase64UrlReverseIndexTable has incorrect entries."); +// Implementation of Base64 encoding and decoding functions. template -/* static */ std::string -Base64::encodeImpl(const T& data, const Charset& charset, bool include_pad) { +/* static */ std::string Base64::encodeImpl( + const T& data, + const Base64::Charset& charset, + bool include_pad) { size_t outlen = calculateEncodedSize(data.size(), include_pad); - std::string out; out.resize(outlen); - encodeImpl(data, charset, include_pad, out.data()); return out; } @@ -183,7 +177,7 @@ void Base64::encodeUrl(const char* data, size_t len, char* output) { template /* static */ void Base64::encodeImpl( const T& data, - const Charset& charset, + const Base64::Charset& charset, bool include_pad, char* out) { auto len = data.size(); @@ -218,22 +212,24 @@ template *wp++ = charset[(curr >> 12) & 0x3f]; *wp++ = charset[(curr >> 6) & 0x3f]; if (include_pad) { - *wp = kBase64Pad; + *wp = kPadding; } } else { *wp++ = charset[(curr >> 12) & 0x3f]; if (include_pad) { - *wp++ = kBase64Pad; - *wp = kBase64Pad; + *wp++ = kPadding; + *wp = kPadding; } } } } +// static std::string Base64::encode(folly::StringPiece text) { return encodeImpl(text, kBase64Charset, true); } +// static std::string Base64::encode(const char* data, size_t len) { return encode(folly::StringPiece(data, len)); } @@ -284,24 +280,19 @@ class IOBufWrapper { } // namespace +// static std::string Base64::encode(const folly::IOBuf* data) { return encodeImpl(IOBufWrapper(data), kBase64Charset, true); } -void Base64::encodeAppend(folly::StringPiece text, std::string& out) { - size_t outlen = calculateEncodedSize(text.size(), true); - - size_t initialLen = out.size(); - out.resize(initialLen + outlen); - encodeImpl(text, kBase64Charset, true, out.data() + initialLen); -} - +// static std::string Base64::decode(folly::StringPiece encoded) { std::string output; Base64::decode(std::make_pair(encoded.data(), encoded.size()), output); return output; } +// static void Base64::decode( const std::pair& payload, std::string& output) { @@ -316,18 +307,18 @@ void Base64::decode(const char* data, size_t size, char* output) { Base64::decode(data, size, output, out_len); } -uint8_t Base64::Base64ReverseLookup( +// static +uint8_t Base64::base64ReverseLookup( char p, - const Base64::ReverseIndex& reverse_lookup) { - auto curr = reverse_lookup[(uint8_t)p]; + const Base64::ReverseIndex& reverseIndex) { + auto curr = reverseIndex[(uint8_t)p]; if (curr >= 0x40) { - throw Base64Exception( - "Base64::decode() - invalid input string: invalid characters"); + VELOX_USER_FAIL("decode() - invalid input string: invalid characters"); } - return curr; } +// static size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); @@ -344,13 +335,13 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { // If padded, ensure that the string length is a multiple of the encoded // block size if (size % kEncodedBlockByteSize != 0) { - throw Base64Exception( + VELOX_USER_FAIL( "Base64::decode() - invalid input string: " "string length is not a multiple of 4."); } auto needed = (size * kBinaryBlockByteSize) / kEncodedBlockByteSize; - auto padding = countPadding(data, size); + auto padding = numPadding(data, size); size -= padding; // Adjust the needed size by deducting the bytes corresponding to the @@ -366,7 +357,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { // Adjust the needed size for extra bytes, if present if (extra) { if (extra == 1) { - throw Base64Exception( + VELOX_USER_FAIL( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } @@ -376,19 +367,20 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) { return needed; } +// static size_t Base64::decodeImpl( const char* src, size_t src_len, char* dst, size_t dst_len, - const ReverseIndex& reverse_lookup) { + const Base64::ReverseIndex& reverseIndex) { if (!src_len) { return 0; } auto needed = calculateDecodedSize(src, src_len); if (dst_len < needed) { - throw Base64Exception( + VELOX_USER_FAIL( "Base64::decode() - invalid output string: " "output string is too small."); } @@ -398,10 +390,10 @@ size_t Base64::decodeImpl( // Each character of the 4 encode 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8 bit bytes. - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12) | - (Base64ReverseLookup(src[2], reverse_lookup) << 6) | - Base64ReverseLookup(src[3], reverse_lookup); + uint32_t last = (Base64::base64ReverseLookup(src[0], reverseIndex) << 18) | + (Base64::base64ReverseLookup(src[1], reverseIndex) << 12) | + (Base64::base64ReverseLookup(src[2], reverseIndex) << 6) | + Base64::base64ReverseLookup(src[3], reverseIndex); dst[0] = (last >> 16) & 0xff; dst[1] = (last >> 8) & 0xff; dst[2] = last & 0xff; @@ -410,14 +402,14 @@ size_t Base64::decodeImpl( // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(src_len >= 2); - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12); + uint32_t last = (Base64::base64ReverseLookup(src[0], reverseIndex) << 18) | + (Base64::base64ReverseLookup(src[1], reverseIndex) << 12); dst[0] = (last >> 16) & 0xff; if (src_len > 2) { - last |= Base64ReverseLookup(src[2], reverse_lookup) << 6; + last |= Base64::base64ReverseLookup(src[2], reverseIndex) << 6; dst[1] = (last >> 8) & 0xff; if (src_len > 3) { - last |= Base64ReverseLookup(src[3], reverse_lookup); + last |= Base64::base64ReverseLookup(src[3], reverseIndex); dst[2] = last & 0xff; } } @@ -425,18 +417,22 @@ size_t Base64::decodeImpl( return needed; } +// static std::string Base64::encodeUrl(folly::StringPiece text) { return encodeImpl(text, kBase64UrlCharset, false); } +// static std::string Base64::encodeUrl(const char* data, size_t len) { return encodeUrl(folly::StringPiece(data, len)); } +// static std::string Base64::encodeUrl(const folly::IOBuf* data) { return encodeImpl(IOBufWrapper(data), kBase64UrlCharset, false); } +// static void Base64::decodeUrl( const char* src, size_t src_len, @@ -445,12 +441,14 @@ void Base64::decodeUrl( decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable); } +// static std::string Base64::decodeUrl(folly::StringPiece encoded) { std::string output; Base64::decodeUrl(std::make_pair(encoded.data(), encoded.size()), output); return output; } +// static void Base64::decodeUrl( const std::pair& payload, std::string& output) { @@ -464,4 +462,5 @@ void Base64::decodeUrl( kBase64UrlReverseIndexTable); output.resize(out_len); } + } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 2c7de463ea6f..e6c7ceae26bd 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -15,6 +15,7 @@ */ #pragma once +#include #include #include #include @@ -24,24 +25,30 @@ namespace facebook::velox::encoding { -class Base64Exception : public std::exception { +class Base64 { public: - explicit Base64Exception(const char* msg) : msg_(msg) {} - const char* what() const noexcept override { - return msg_; - } + static const size_t kCharsetSize = 64; + static const size_t kReverseIndexSize = 256; - protected: - const char* msg_; -}; + /// Character set used for encoding purposes. + /// Contains specific characters that form the encoding scheme. + using Charset = std::array; -class Base64 { - public: - using Charset = std::array; - using ReverseIndex = std::array; + /// Reverse lookup table for decoding purposes. + /// Maps each possible encoded character to its corresponding numeric value + /// within the encoding base. + using ReverseIndex = std::array; + + /// Padding character used in encoding. + static const char kPadding = '='; + /// Encodes the specified number of characters from the 'data'. static std::string encode(const char* data, size_t len); + + /// Encodes the specified text. static std::string encode(folly::StringPiece text); + + /// Encodes the specified IOBuf data. static std::string encode(const folly::IOBuf* text); /// Returns encoded size for the input of the specified size. @@ -52,9 +59,7 @@ class Base64 { /// returned by the calculateEncodedSize(). static void encode(const char* data, size_t size, char* output); - // Appends the encoded text to out. - static void encodeAppend(folly::StringPiece text, std::string& out); - + /// Decodes the specified encoded text. static std::string decode(folly::StringPiece encoded); /// Returns the actual size of the decoded data. Will also remove the padding @@ -71,49 +76,104 @@ class Base64 { std::string& output); /// Encodes the specified number of characters from the 'data' and writes the - /// result to the 'output'. The output must have enough space, e.g. as - /// returned by the calculateEncodedSize(). + /// result to the 'output' using URL encoding. The output must have enough + /// space as returned by the calculateEncodedSize(). static void encodeUrl(const char* data, size_t size, char* output); - // compatible with www's Base64URL::encode/decode - // TODO rename encode_url/decode_url to encodeUrl/encodeUrl. + /// Encodes the specified number of characters from the 'data' using URL + /// encoding. static std::string encodeUrl(const char* data, size_t len); + + /// Encodes the specified IOBuf data using URL encoding. static std::string encodeUrl(const folly::IOBuf* data); + + /// Encodes the specified text using URL encoding. static std::string encodeUrl(folly::StringPiece text); + + /// Decodes the specified URL encoded payload and writes the result to the + /// 'output'. static void decodeUrl( const std::pair& payload, std::string& output); + + /// Decodes the specified URL encoded text. static std::string decodeUrl(folly::StringPiece text); + /// Decodes the specified number of characters from the 'src' and writes the + /// result to the 'dst'. static size_t decode(const char* src, size_t src_len, char* dst, size_t dst_len); + /// Decodes the specified number of characters from the 'src' using URL + /// encoding and writes the result to the 'dst'. static void decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); - constexpr static char kBase64Pad = '='; - - private: + /// Checks if there is padding in encoded data. static inline bool isPadded(const char* data, size_t len) { - return (len > 0 && data[len - 1] == kBase64Pad); + return (len > 0 && data[len - 1] == kPadding); } - static inline size_t countPadding(const char* src, size_t len) { + /// Counts the number of padding characters in encoded data. + static inline size_t numPadding(const char* src, size_t len) { size_t numPadding{0}; - while (len > 0 && src[len - 1] == kBase64Pad) { + while (len > 0 && src[len - 1] == kPadding) { numPadding++; len--; } - return numPadding; } - static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table); + // Validate the character in charset with ReverseIndex table + static constexpr bool checkForwardIndex( + uint8_t idx, + const Charset& charset, + const ReverseIndex& reverseIndex) { + for (uint8_t i = 0; i <= idx; ++i) { + if (!(reverseIndex[static_cast(charset[i])] == i)) { + return false; + } + } + return true; + } + + /// Searches for a character within a charset up to a certain index. + static const bool + findCharacterInCharSet(const Charset& charset, uint8_t idx, const char c); + /// Checks the consistency of a reverse index mapping for a given character + /// set. + static constexpr bool checkReverseIndex( + uint8_t idx, + const Charset& charset, + const ReverseIndex& reverseIndex) { + for (uint8_t currentIdx = idx; currentIdx != static_cast(-1); + --currentIdx) { + if (reverseIndex[currentIdx] == 255) { + if (Base64::findCharacterInCharSet( + charset, 0, static_cast(currentIdx))) { + return false; + } + } else { + if (!(charset[reverseIndex[currentIdx]] == currentIdx)) { + return false; + } + } + } + return true; + } + + /// Performs a reverse lookup in the reverse index to retrieve the original + /// index of a character in the base. + static uint8_t base64ReverseLookup(char p, const ReverseIndex& reverseIndex); + + private: + /// Encodes the specified data using the provided charset. template static std::string encodeImpl(const T& data, const Charset& charset, bool include_pad); + /// Encodes the specified data using the provided charset. template static void encodeImpl( const T& data, @@ -121,6 +181,7 @@ class Base64 { bool include_pad, char* out); + /// Decodes the specified data using the provided reverse lookup table. static size_t decodeImpl( const char* src, size_t src_len, diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 15556583c751..2098219ec0de 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -18,6 +18,8 @@ #include #include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/base/Exceptions.h" + namespace facebook::velox::encoding { class Base64Test : public ::testing::Test {}; @@ -61,7 +63,7 @@ TEST_F(Base64Test, calculateDecodedSizeProperSize) { encoded_size = 21; EXPECT_THROW( Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), - facebook::velox::encoding::Base64Exception); + VeloxUserError); encoded_size = 32; EXPECT_EQ( @@ -86,4 +88,56 @@ TEST_F(Base64Test, calculateDecodedSizeProperSize) { EXPECT_EQ(14, encoded_size); } +TEST_F(Base64Test, ChecksPadding) { + EXPECT_TRUE(Base64::isPadded("ABC=", 4)); + EXPECT_FALSE(Base64::isPadded("ABC", 3)); +} + +TEST_F(Base64Test, CountsPaddingCorrectly) { + EXPECT_EQ(0, Base64::numPadding("ABC", 3)); + EXPECT_EQ(1, Base64::numPadding("ABC=", 4)); + EXPECT_EQ(2, Base64::numPadding("AB==", 4)); +} + +constexpr Base64::Charset testCharset = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'}; + +constexpr Base64::ReverseIndex testReverseIndex = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, + 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, + 255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, + 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, + 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, + 49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255}; + +TEST_F(Base64Test, HandlesLookupAndExceptions) { + EXPECT_NO_THROW(Base64::base64ReverseLookup('A', testReverseIndex)); + EXPECT_THROW( + Base64::base64ReverseLookup('=', testReverseIndex), VeloxUserError); +} + +TEST_F(Base64Test, ValidatesCharsetWithReverseIndex) { + EXPECT_TRUE(Base64::checkForwardIndex(63, testCharset, testReverseIndex)); +} + +TEST_F(Base64Test, ValidatesReverseIndexWithCharset) { + EXPECT_TRUE(Base64::checkReverseIndex(255, testCharset, testReverseIndex)); +} + } // namespace facebook::velox::encoding