diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 85fd843b86a8..b79be3d12bfc 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -74,13 +74,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = { 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255}; -constexpr bool checkForwardIndex( - uint8_t idx, - const Base64::Charset& charset, - const Base64::ReverseIndex& table) { - return (table[static_cast(charset[idx])] == idx) && - (idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true); -} // Verify that for every entry in kBase64Charset, the corresponding entry // in kBase64ReverseIndexTable is correct. static_assert( @@ -324,51 +317,49 @@ uint8_t Base64::Base64ReverseLookup( size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { - return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true); + return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); } -// static -size_t -Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) { +size_t Base64::calculateDecodedSize(const char* data, size_t& size) { if (size == 0) { return 0; } - auto needed = (size / 4) * 3; - if (withPadding) { - // If the pad characters are included then the source string must be a - // multiple of 4 and we can query the end of the string to see how much - // padding exists. - if (size % 4 != 0) { - throw Base64Exception( + // Check if the input data is padded + if (isPadded(data, size)) { + /// If padded, ensure that the string length is a multiple of the encoded + /// block size + if (size % kEncodedBlockSize != 0) { + throw EncoderException( "Base64::decode() - invalid input string: " - "string length is not multiple of 4."); + "string length is not a multiple of the encoded block size."); } + auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize; auto padding = countPadding(data, size); size -= padding; - return needed - padding; - } - // If padding doesn't exist we need to calculate it from the size - if the - // size % 4 is 0 then we have an even multiple 3 byte chunks in the result - // if it is 2 then we need 1 more byte in the output. If it is 3 then we - // need 2 more bytes in the output. It should never be 1. - auto extra = size % 4; - if (extra) { - if (extra == 1) { - throw Base64Exception( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4."); + // Adjust the needed size for padding + return needed - + ceil((padding * kBinaryBlockSize) / + static_cast(kEncodedBlockSize)); + } else { + // If not padded, Calculate extra bytes, if any + auto extra = size % kEncodedBlockSize; + auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize; + + // Adjust the needed size for extra bytes, if present + if (extra) { + if (extra == 1) { + throw EncoderException( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4."); + } + needed += (extra * kBinaryBlockSize) / kEncodedBlockSize; } - return needed + extra - 1; - } - // Just because we don't need the pad, doesn't mean it is not there. The - // URL decoder should be able to handle the original encoding. - auto padding = countPadding(data, size); - size -= padding; - return needed - padding; + return needed; + } } size_t Base64::decodeImpl( @@ -376,15 +367,14 @@ size_t Base64::decodeImpl( size_t src_len, char* dst, size_t dst_len, - const Base64::ReverseIndex& reverse_lookup, - bool include_pad) { + const ReverseIndex& reverse_lookup) { if (!src_len) { return 0; } - auto needed = calculateDecodedSize(src, src_len, include_pad); + auto needed = calculateDecodedSize(src, src_len); if (dst_len < needed) { - throw Base64Exception( + throw EncoderException( "Base64::decode() - invalid output string: " "output string is too small."); } @@ -394,10 +384,10 @@ size_t Base64::decodeImpl( // Each character of the 4 encode 6 bits of the original, grab each with // the appropriate shifts to rebuild the original and then split that back // into the original 8 bit bytes. - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12) | - (Base64ReverseLookup(src[2], reverse_lookup) << 6) | - Base64ReverseLookup(src[3], reverse_lookup); + uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase, src[1], reverse_lookup) << 12) | + (baseReverseLookup(kBase, src[2], reverse_lookup) << 6) | + baseReverseLookup(kBase, src[3], reverse_lookup); dst[0] = (last >> 16) & 0xff; dst[1] = (last >> 8) & 0xff; dst[2] = last & 0xff; @@ -406,14 +396,14 @@ size_t Base64::decodeImpl( // Handle the last 2-4 characters. This is similar to the above, but the // last 2 characters may or may not exist. DCHECK(src_len >= 2); - uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) | - (Base64ReverseLookup(src[1], reverse_lookup) << 12); + uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) | + (baseReverseLookup(kBase, src[1], reverse_lookup) << 12); dst[0] = (last >> 16) & 0xff; if (src_len > 2) { - last |= Base64ReverseLookup(src[2], reverse_lookup) << 6; + last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 6; dst[1] = (last >> 8) & 0xff; if (src_len > 3) { - last |= Base64ReverseLookup(src[3], reverse_lookup); + last |= baseReverseLookup(kBase, src[3], reverse_lookup); dst[2] = last & 0xff; } } @@ -437,9 +427,8 @@ void Base64::decodeUrl( const char* src, size_t src_len, char* dst, - size_t dst_len, - bool hasPad) { - decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad); + size_t dst_len) { + decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable); } std::string Base64::decodeUrl(folly::StringPiece encoded) { @@ -458,8 +447,7 @@ void Base64::decodeUrl( payload.second, &output[0], out_len, - kBase64UrlReverseIndexTable, - false); + kBase64UrlReverseIndexTable); output.resize(out_len); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 9888d97e67c5..a5a000c713cd 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -21,6 +21,7 @@ #include #include +#include "velox/common/encode/EncoderUtils.h" namespace facebook::velox::encoding { @@ -59,8 +60,7 @@ class Base64 { /// Returns decoded size for the specified input. Adjusts the 'size' to /// subtract the length of the padding, if exists. - static size_t - calculateDecodedSize(const char* data, size_t& size, bool withPadding = true); + static size_t calculateDecodedSize(const char* data, size_t& size); /// Decodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -69,7 +69,7 @@ class Base64 { static void decode( const std::pair& payload, - std::string& outp); + std::string& output); /// Encodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -89,12 +89,8 @@ class Base64 { static size_t decode(const char* src, size_t src_len, char* dst, size_t dst_len); - static void decodeUrl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, - bool pad); + static void + decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); constexpr static char kBase64Pad = '='; @@ -122,8 +118,18 @@ class Base64 { size_t src_len, char* dst, size_t dst_len, - const ReverseIndex& table, - bool include_pad); + const ReverseIndex& table); + + public: + // Encoding base to be used. + constexpr static int kBase = 64; + + private: + // Size of the binary block before encoding. + constexpr static int kBinaryBlockSize = 3; + + // Size of the encoded block after encoding. + constexpr static int kEncodedBlockSize = 4; }; } // namespace facebook::velox::encoding diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index 5c8d4bff0636..03066be0e398 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -292,8 +292,9 @@ struct FromBase64Function { auto inputSize = input.size(); result.resize( encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode(input.data(), input.size(), result.data()); - } catch (const encoding::Base64Exception& e) { + encoding::Base64::decode( + input.data(), inputSize, result.data(), result.size()); + } catch (const encoding::EncoderException& e) { VELOX_USER_FAIL(e.what()); } } @@ -306,15 +307,11 @@ struct FromBase64UrlFunction { FOLLY_ALWAYS_INLINE void call( out_type& result, const arg_type& input) { - auto inputData = input.data(); auto inputSize = input.size(); - bool hasPad = - inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad); result.resize( - encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad)); - hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad. + encoding::Base64::calculateDecodedSize(input.data(), inputSize)); encoding::Base64::decodeUrl( - inputData, inputSize, result.data(), result.size(), hasPad); + input.data(), inputSize, result.data(), result.size()); } }; diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index 0a647d2ecf52..a908abd17a43 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -424,11 +424,16 @@ TEST_F(BinaryFunctionsTest, fromBase64) { EXPECT_EQ(std::nullopt, fromBase64(std::nullopt)); EXPECT_EQ("", fromBase64("")); EXPECT_EQ("a", fromBase64("YQ==")); + EXPECT_EQ("ab", fromBase64("YWI=")); EXPECT_EQ("abc", fromBase64("YWJj")); EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ=")); EXPECT_EQ( "Hello World from Velox!", fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE=")); + // Check encoded strings without padding + EXPECT_EQ("a", fromBase64("YQ")); + EXPECT_EQ("ab", fromBase64("YWI")); + EXPECT_EQ("abcd", fromBase64("YWJjZA")); EXPECT_THROW(fromBase64("YQ="), VeloxUserError); EXPECT_THROW(fromBase64("YQ==="), VeloxUserError);