diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 85fd843b86a8..8014362932c7 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -22,6 +22,10 @@ namespace facebook::velox::encoding { +// Constants defining the size of binary and encoded blocks for Base64 encoding. +constexpr static int kBinaryBlockSize = 3; // 3 bytes of binary = 24 bits +constexpr static int kEncodedBlockSize = 4; // 4 bytes of encoded = 24 bits + constexpr const Base64::Charset kBase64Charset = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', @@ -298,10 +302,9 @@ std::string Base64::decode(folly::StringPiece encoded) { void Base64::decode( const std::pair& payload, std::string& output) { - size_t out_len = payload.second / 4 * 3; - output.resize(out_len, '\0'); - out_len = Base64::decode(payload.first, payload.second, &output[0], out_len); - output.resize(out_len); + size_t inputSize = payload.second; + output.resize(calculateDecodedSize(payload.first, inputSize)); + decode(payload.first, inputSize, output.data(), output.size()); } // static @@ -324,51 +327,49 @@ uint8_t Base64::Base64ReverseLookup( size_t Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) { - return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true); + return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable); } // static -size_t -Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) { +size_t Base64::calculateDecodedSize(const char* data, size_t& size) { if (size == 0) { return 0; } - auto needed = (size / 4) * 3; - if (withPadding) { - // If the pad characters are included then the source string must be a - // multiple of 4 and we can query the end of the string to see how much - // padding exists. - if (size % 4 != 0) { + // Check if the input data is padded + if (isPadded(data, size)) { + // If padded, ensure that the string length is a multiple of the encoded + // block size + if (size % kEncodedBlockSize != 0) { throw Base64Exception( "Base64::decode() - invalid input string: " - "string length is not multiple of 4."); + "string length is not a multiple of 4."); } + auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize; auto padding = countPadding(data, size); size -= padding; - return needed - padding; + + // Adjust the needed size for padding + return needed - + ceil((padding * kBinaryBlockSize) / + static_cast(kEncodedBlockSize)); } + // If not padded, Calculate extra bytes, if any + auto extra = size % kEncodedBlockSize; + auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize; - // If padding doesn't exist we need to calculate it from the size - if the - // size % 4 is 0 then we have an even multiple 3 byte chunks in the result - // if it is 2 then we need 1 more byte in the output. If it is 3 then we - // need 2 more bytes in the output. It should never be 1. - auto extra = size % 4; + // Adjust the needed size for extra bytes, if present if (extra) { if (extra == 1) { throw Base64Exception( "Base64::decode() - invalid input string: " "string length cannot be 1 more than a multiple of 4."); } - return needed + extra - 1; + needed += (extra * kBinaryBlockSize) / kEncodedBlockSize; } - // Just because we don't need the pad, doesn't mean it is not there. The - // URL decoder should be able to handle the original encoding. - auto padding = countPadding(data, size); - size -= padding; - return needed - padding; + return needed; } size_t Base64::decodeImpl( @@ -376,13 +377,12 @@ size_t Base64::decodeImpl( size_t src_len, char* dst, size_t dst_len, - const Base64::ReverseIndex& reverse_lookup, - bool include_pad) { + const ReverseIndex& reverse_lookup) { if (!src_len) { return 0; } - auto needed = calculateDecodedSize(src, src_len, include_pad); + auto needed = calculateDecodedSize(src, src_len); if (dst_len < needed) { throw Base64Exception( "Base64::decode() - invalid output string: " @@ -437,9 +437,8 @@ void Base64::decodeUrl( const char* src, size_t src_len, char* dst, - size_t dst_len, - bool hasPad) { - decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad); + size_t dst_len) { + decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable); } std::string Base64::decodeUrl(folly::StringPiece encoded) { @@ -458,8 +457,7 @@ void Base64::decodeUrl( payload.second, &output[0], out_len, - kBase64UrlReverseIndexTable, - false); + kBase64UrlReverseIndexTable); output.resize(out_len); } } // namespace facebook::velox::encoding diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 9888d97e67c5..72427752e5ee 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -57,10 +57,9 @@ class Base64 { static std::string decode(folly::StringPiece encoded); - /// Returns decoded size for the specified input. Adjusts the 'size' to - /// subtract the length of the padding, if exists. - static size_t - calculateDecodedSize(const char* data, size_t& size, bool withPadding = true); + /// Returns the actual size of the decoded data. Will also remove the padding + /// length from the input data size. + static size_t calculateDecodedSize(const char* data, size_t& size); /// Decodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -69,7 +68,7 @@ class Base64 { static void decode( const std::pair& payload, - std::string& outp); + std::string& output); /// Encodes the specified number of characters from the 'data' and writes the /// result to the 'output'. The output must have enough space, e.g. as @@ -89,19 +88,24 @@ class Base64 { static size_t decode(const char* src, size_t src_len, char* dst, size_t dst_len); - static void decodeUrl( - const char* src, - size_t src_len, - char* dst, - size_t dst_len, - bool pad); + static void + decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len); constexpr static char kBase64Pad = '='; private: + static inline bool isPadded(const char* data, size_t len) { + return (len > 0 && data[len - 1] == kBase64Pad); + } + static inline size_t countPadding(const char* src, size_t len) { - DCHECK_GE(len, 2); - return src[len - 1] != kBase64Pad ? 0 : src[len - 2] != kBase64Pad ? 1 : 2; + size_t numPadding{0}; + while (len > 0 && src[len - 1] == kBase64Pad) { + numPadding++; + len--; + } + + return numPadding; } static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table); @@ -122,8 +126,7 @@ class Base64 { size_t src_len, char* dst, size_t dst_len, - const ReverseIndex& table, - bool include_pad); + const ReverseIndex& table); }; } // namespace facebook::velox::encoding diff --git a/velox/common/encode/CMakeLists.txt b/velox/common/encode/CMakeLists.txt index d9918d53b59c..bc27527e14ac 100644 --- a/velox/common/encode/CMakeLists.txt +++ b/velox/common/encode/CMakeLists.txt @@ -12,5 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +if(${VELOX_BUILD_TESTING}) + add_subdirectory(tests) +endif() + add_library(velox_encode Base64.cpp) target_link_libraries(velox_encode PUBLIC Folly::folly) diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp new file mode 100644 index 000000000000..caa5c73db252 --- /dev/null +++ b/velox/common/encode/tests/Base64Test.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/common/encode/Base64.h" +#include +#include "velox/common/base/tests/GTestUtils.h" + +namespace facebook::velox::encoding { +class Base64Test : public ::testing::Test {}; + +TEST_F(Base64Test, fromBase64) { + EXPECT_EQ( + "Hello, World!", + Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ=="))); + EXPECT_EQ( + "Base64 encoding is fun.", + Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4="))); + EXPECT_EQ( + "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ="))); + EXPECT_EQ( + "1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA=="))); + + // Check encoded strings without padding + EXPECT_EQ( + "Hello, World!", + Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ"))); + EXPECT_EQ( + "Base64 encoding is fun.", + Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4"))); + EXPECT_EQ( + "Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ"))); + EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA"))); +} + +TEST_F(Base64Test, calculateDecodedSizeProperSize) { + size_t encoded_size{0}; + + encoded_size = 20; + EXPECT_EQ( + 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size)); + EXPECT_EQ(18, encoded_size); + + encoded_size = 18; + EXPECT_EQ( + 13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size)); + EXPECT_EQ(18, encoded_size); + + encoded_size = 21; + EXPECT_THROW( + Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size), + facebook::velox::encoding::Base64Exception); + + encoded_size = 32; + EXPECT_EQ( + 23, + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size)); + EXPECT_EQ(31, encoded_size); + + encoded_size = 31; + EXPECT_EQ( + 23, + Base64::calculateDecodedSize( + "QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size)); + EXPECT_EQ(31, encoded_size); + + encoded_size = 16; + EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size)); + EXPECT_EQ(14, encoded_size); + + encoded_size = 14; + EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size)); + EXPECT_EQ(14, encoded_size); +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt new file mode 100644 index 000000000000..11caf35c8416 --- /dev/null +++ b/velox/common/encode/tests/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_executable(velox_common_encode_test Base64Test.cpp) +add_test(velox_common_encode_test velox_common_encode_test) +target_link_libraries( + velox_common_encode_test + PUBLIC Folly::folly + PRIVATE velox_encode + velox_exception + gtest + gtest_main) diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index c3a3b96a1190..94180fe2ab3b 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -284,7 +284,6 @@ struct ToBase64Function { template struct FromBase64Function { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( out_type& result, const arg_type& input) { @@ -292,7 +291,8 @@ struct FromBase64Function { auto inputSize = input.size(); result.resize( encoding::Base64::calculateDecodedSize(input.data(), inputSize)); - encoding::Base64::decode(input.data(), input.size(), result.data()); + encoding::Base64::decode( + input.data(), inputSize, result.data(), result.size()); } catch (const encoding::Base64Exception& e) { VELOX_USER_FAIL(e.what()); } @@ -302,19 +302,14 @@ struct FromBase64Function { template struct FromBase64UrlFunction { VELOX_DEFINE_FUNCTION_TYPES(T); - FOLLY_ALWAYS_INLINE void call( out_type& result, const arg_type& input) { - auto inputData = input.data(); auto inputSize = input.size(); - bool hasPad = - inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad); result.resize( - encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad)); - hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad. + encoding::Base64::calculateDecodedSize(input.data(), inputSize)); encoding::Base64::decodeUrl( - inputData, inputSize, result.data(), result.size(), hasPad); + input.data(), inputSize, result.data(), result.size()); } }; diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp index e690674ba431..bcd3d3cdaf59 100644 --- a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp +++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp @@ -425,6 +425,7 @@ TEST_F(BinaryFunctionsTest, fromBase64) { EXPECT_EQ(std::nullopt, fromBase64(std::nullopt)); EXPECT_EQ("", fromBase64("")); EXPECT_EQ("a", fromBase64("YQ==")); + EXPECT_EQ("ab", fromBase64("YWI=")); EXPECT_EQ("abc", fromBase64("YWJj")); EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ=")); EXPECT_EQ( @@ -433,6 +434,11 @@ TEST_F(BinaryFunctionsTest, fromBase64) { EXPECT_THROW(fromBase64("YQ="), VeloxUserError); EXPECT_THROW(fromBase64("YQ==="), VeloxUserError); + + // Check encoded strings without padding + EXPECT_EQ("a", fromBase64("YQ")); + EXPECT_EQ("ab", fromBase64("YWI")); + EXPECT_EQ("abcd", fromBase64("YWJjZA")); } TEST_F(BinaryFunctionsTest, fromBase64Url) {