diff --git a/velox/common/encode/Base32.cpp b/velox/common/encode/Base32.cpp
index c36cb6a706b5..dcbfd46725c5 100644
--- a/velox/common/encode/Base32.cpp
+++ b/velox/common/encode/Base32.cpp
@@ -179,4 +179,146 @@ template
   }
 }
 
+// Returns the number of bytes needed to decode the first 'size' characters of
+// 'data'. If the input is padded, 'size' is reduced by the number of padding
+// characters. Throws EncoderException if the length is not valid for a base32
+// string.
+size_t Base32::calculateDecodedSize(const char* data, size_t& size) {
+  if (size == 0) {
+    return 0;
+  }
+
+  // Check if the input data is padded
+  if (isPadded(data, size)) {
+    /// If padded, ensure that the string length is a multiple of the encoded
+    /// block size.
+    if (size % kEncodedBlockSize != 0) {
+      throw EncoderException(
+          "Base32::decode() - invalid input string: "
+          "string length is not a multiple of 8.");
+    }
+
+    auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize;
+    auto padding = countPadding(data, size);
+    size -= padding;
+
+    // The characters left after removing the padding must obey the same length
+    // rule as unpadded input; otherwise the final group cannot be decoded and
+    // decodeImpl() would read past the characters it was given.
+    auto extra = size % kEncodedBlockSize;
+    if ((extra == 6) || (extra == 3) || (extra == 1)) {
+      throw EncoderException(
+          "Base32::decode() - invalid input string: "
+          "invalid number of padding characters.");
+    }
+
+    // The padding removes ceil(padding * kBinaryBlockSize / kEncodedBlockSize)
+    // bytes from the size of the fully populated blocks. Integer ceiling
+    // division avoids the round trip through floating point.
+    return needed -
+        (padding * kBinaryBlockSize + kEncodedBlockSize - 1) /
+        kEncodedBlockSize;
+  } else {
+    // If not padded, calculate extra bytes, if any.
+    auto extra = size % kEncodedBlockSize;
+    auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize;
+
+    // Adjust the needed size for extra bytes, if present.
+    if (extra) {
+      if ((extra == 6) || (extra == 3) || (extra == 1)) {
+        throw EncoderException(
+            "Base32::decode() - invalid input string: "
+            "string length cannot be 6, 3 or 1 more than a multiple of 8.");
+      }
+      needed += (extra * kBinaryBlockSize) / kEncodedBlockSize;
+    }
+
+    return needed;
+  }
+}
+
+// Decodes 'src_len' characters from 'src' into 'dst' using
+// kBase32ReverseIndexTable and returns the decoded size.
+size_t
+Base32::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
+  return decodeImpl(src, src_len, dst, dst_len, kBase32ReverseIndexTable);
+}
+
+// Decodes 'src_len' characters from 'src' into 'dst', mapping every character
+// through 'reverse_lookup', and returns the decoded size. Throws
+// EncoderException if the input length is invalid or 'dst_len' is too small.
+size_t Base32::decodeImpl(
+    const char* src,
+    size_t src_len,
+    char* dst,
+    size_t dst_len,
+    const ReverseIndex& reverse_lookup) {
+  if (!src_len) {
+    return 0;
+  }
+
+  // Validates the length and strips the padding, if any, from 'src_len'.
+  auto needed = calculateDecodedSize(src, src_len);
+  if (dst_len < needed) {
+    throw EncoderException(
+        "Base32::decode() - invalid output string: "
+        "output string is too small.");
+  }
+
+  // The input held nothing but padding, so there is nothing left to decode.
+  if (!src_len) {
+    return needed;
+  }
+
+  // Handle full groups of 8 characters. The last group, full or partial, is
+  // left to the code after the loop.
+  for (; src_len > 8; src_len -= 8, src += 8, dst += 5) {
+    /// Each of the 8 characters encodes 5 bits of the original, grab each
+    /// with the appropriate shifts to rebuild the original and then split that
+    /// back into the original 8 bit bytes.
+    uint64_t last =
+        (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) |
+        (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30) |
+        (baseReverseLookup(kBase, src[2], reverse_lookup) << 25) |
+        (baseReverseLookup(kBase, src[3], reverse_lookup) << 20) |
+        (baseReverseLookup(kBase, src[4], reverse_lookup) << 15) |
+        (baseReverseLookup(kBase, src[5], reverse_lookup) << 10) |
+        (baseReverseLookup(kBase, src[6], reverse_lookup) << 5) |
+        baseReverseLookup(kBase, src[7], reverse_lookup);
+    dst[0] = (last >> 32) & 0xff;
+    dst[1] = (last >> 24) & 0xff;
+    dst[2] = (last >> 16) & 0xff;
+    dst[3] = (last >> 8) & 0xff;
+    dst[4] = last & 0xff;
+  }
+
+  /// Handle the last 2, 4, 5, 7 or 8 characters. This is similar to the above,
+  /// but the last characters may or may not exist.
+  DCHECK(src_len >= 2);
+  uint64_t last =
+      (uint64_t(baseReverseLookup(kBase, src[0], reverse_lookup)) << 35) |
+      (uint64_t(baseReverseLookup(kBase, src[1], reverse_lookup)) << 30);
+  dst[0] = (last >> 32) & 0xff;
+  if (src_len > 2) {
+    last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 25;
+    last |= baseReverseLookup(kBase, src[3], reverse_lookup) << 20;
+    dst[1] = (last >> 24) & 0xff;
+    if (src_len > 4) {
+      last |= baseReverseLookup(kBase, src[4], reverse_lookup) << 15;
+      dst[2] = (last >> 16) & 0xff;
+      if (src_len > 5) {
+        last |= baseReverseLookup(kBase, src[5], reverse_lookup) << 10;
+        last |= baseReverseLookup(kBase, src[6], reverse_lookup) << 5;
+        dst[3] = (last >> 8) & 0xff;
+        if (src_len > 7) {
+          last |= baseReverseLookup(kBase, src[7], reverse_lookup);
+          dst[4] = last & 0xff;
+        }
+      }
+    }
+  }
+
+  return needed;
+}
+
 } // namespace facebook::velox::encoding
diff --git a/velox/common/encode/Base32.h b/velox/common/encode/Base32.h
index 72d223659e03..3b5243644cc8 100644
--- a/velox/common/encode/Base32.h
+++ b/velox/common/encode/Base32.h
@@ -34,7 +34,26 @@ class Base32 {
   /// returned by the calculateEncodedSize().
   static void encode(const char* data, size_t size, char* output);
 
+  /// Returns decoded size for the specified input. Adjusts the 'size' to
+  /// subtract the length of the padding, if exists.
+  static size_t calculateDecodedSize(const char* data, size_t& size);
+
+  /// Decodes the specified number of characters from the 'src' and writes the
+  /// result to the 'dst'. The destination must have enough space, e.g. as
+  /// returned by the calculateDecodedSize().
+  static size_t
+  decode(const char* src, size_t src_len, char* dst, size_t dst_len);
+
  private:
+  /// Decodes the specified number of base 32 encoded characters from the 'src'
+  /// and writes to 'dst'
+  static size_t decodeImpl(
+      const char* src,
+      size_t src_len,
+      char* dst,
+      size_t dst_len,
+      const ReverseIndex& table);
+
   template
   static void encodeImpl(
       const T& data,
diff --git a/velox/common/encode/tests/Base32Test.cpp b/velox/common/encode/tests/Base32Test.cpp
index 8ed3beb1e874..bf1973c56a39 100644
--- a/velox/common/encode/tests/Base32Test.cpp
+++ b/velox/common/encode/tests/Base32Test.cpp
@@ -39,4 +39,40 @@ TEST_F(Base32Test, calculateEncodedSizeProperSize) {
   EXPECT_EQ(24, Base32::calculateEncodedSize(11, true));
 }
 
+TEST_F(Base32Test, calculateDecodedSizeProperSize) {
+  size_t encoded_size{0};
+
+  encoded_size = 8;
+  EXPECT_EQ(1, Base32::calculateDecodedSize("ME======", encoded_size));
+  EXPECT_EQ(2, encoded_size);
+
+  encoded_size = 2;
+  EXPECT_EQ(1, Base32::calculateDecodedSize("ME", encoded_size));
+  EXPECT_EQ(2, encoded_size);
+
+  encoded_size = 9;
+  EXPECT_THROW(
+      Base32::calculateDecodedSize("MFRA====", encoded_size),
+      facebook::velox::encoding::EncoderException);
+
+  encoded_size = 8;
+  EXPECT_EQ(2, Base32::calculateDecodedSize("MFRA====", encoded_size));
+  EXPECT_EQ(4, encoded_size);
+
+  encoded_size = 8;
+  EXPECT_EQ(3, Base32::calculateDecodedSize("MFRGG===", encoded_size));
+  EXPECT_EQ(5, encoded_size);
+
+  encoded_size = 24;
+  EXPECT_EQ(
+      11,
+      Base32::calculateDecodedSize("NBSWY3DPEB3W64TMMQ======", encoded_size));
+  EXPECT_EQ(18, encoded_size);
+
+  encoded_size = 18;
+  EXPECT_EQ(
+      11, Base32::calculateDecodedSize("NBSWY3DPEB3W64TMMQ", encoded_size));
+  EXPECT_EQ(18, encoded_size);
+}
+
 } // namespace facebook::velox::encoding
\ No newline at end of file
diff --git a/velox/docs/functions/presto/binary.rst b/velox/docs/functions/presto/binary.rst
index 3f5ad0a10bcf..20aa5298e31f 100644
--- a/velox/docs/functions/presto/binary.rst
+++
b/velox/docs/functions/presto/binary.rst
@@ -22,6 +22,10 @@ Binary Functions
 
     Decodes ``bigint`` value from a 64-bit 2’s complement big endian ``binary``.
 
+.. function:: from_base32(string) -> varbinary
+
+    Decodes binary data from the base32 encoded ``string``.
+
 .. function:: from_hex(string) -> varbinary
 
     Decodes binary data from the hex encoded ``string``.
diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h
index 2431d5816ebb..bba7fdd530b0 100644
--- a/velox/functions/prestosql/BinaryFunctions.h
+++ b/velox/functions/prestosql/BinaryFunctions.h
@@ -338,6 +338,30 @@ struct ToBase32Function {
   }
 };
 
+/// Decodes a base32 encoded string into binary data. Both padded and unpadded
+/// input is accepted. An encoding::EncoderException raised while decoding is
+/// reported as a user error.
+template <typename T>
+struct FromBase32Function {
+  VELOX_DEFINE_FUNCTION_TYPES(T);
+
+  FOLLY_ALWAYS_INLINE void call(
+      out_type& result,
+      const arg_type& input) {
+    try {
+      // calculateDecodedSize() takes the size by reference and removes the
+      // padding from it, so 'inputSize' must be a mutable copy.
+      auto inputSize = input.size();
+      result.resize(
+          encoding::Base32::calculateDecodedSize(input.data(), inputSize));
+      encoding::Base32::decode(
+          input.data(), inputSize, result.data(), result.size());
+    } catch (const encoding::EncoderException& e) {
+      VELOX_USER_FAIL(e.what());
+    }
+  }
+};
+
 template
 struct FromBigEndian32 {
   VELOX_DEFINE_FUNCTION_TYPES(T);
diff --git a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp
index 79a9697a332f..a72a4c0c6a42 100644
--- a/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp
+++ b/velox/functions/prestosql/registration/BinaryFunctionsRegistration.cpp
@@ -53,6 +53,8 @@ void registerSimpleFunctions(const std::string& prefix) {
       {prefix + "from_base64url"});
   registerFunction(
       {prefix + "to_base32"});
+  registerFunction(
+      {prefix + "from_base32"});
   registerFunction(
       {prefix + "from_big_endian_32"});
 
diff --git a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
index 53541dd93feb..f57b3f4e8eb7 100644
---
a/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
+++ b/velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
@@ -489,6 +489,42 @@ TEST_F(BinaryFunctionsTest, toBase32) {
       toBase32("Hello World from Velox!"));
 }
 
+TEST_F(BinaryFunctionsTest, fromBase32) {
+  const auto fromBase32 = [&](std::optional value) {
+    return evaluateOnce("from_base32(c0)", value);
+  };
+
+  EXPECT_EQ(std::nullopt, fromBase32(std::nullopt));
+  EXPECT_EQ("", fromBase32(""));
+  EXPECT_EQ("a", fromBase32("ME======"));
+  EXPECT_EQ("ab", fromBase32("MFRA===="));
+  EXPECT_EQ("abc", fromBase32("MFRGG==="));
+  EXPECT_EQ("db2", fromBase32("MRRDE==="));
+  EXPECT_EQ("abcd", fromBase32("MFRGGZA="));
+  EXPECT_EQ("hello world", fromBase32("NBSWY3DPEB3W64TMMQ======"));
+  EXPECT_EQ(
+      "Hello World from Velox!",
+      fromBase32("JBSWY3DPEBLW64TMMQQGM4TPNUQFMZLMN54CC==="));
+
+  // Decode the same kind of input with the trailing padding left out.
+  EXPECT_EQ("a", fromBase32("ME"));
+  EXPECT_EQ("ab", fromBase32("MFRA"));
+  EXPECT_EQ("abc", fromBase32("MFRGG"));
+  EXPECT_EQ("db2", fromBase32("MRRDE"));
+  EXPECT_EQ("abcd", fromBase32("MFRGGZA"));
+  EXPECT_EQ("1234", fromBase32("GEZDGNA"));
+  EXPECT_EQ("abcde", fromBase32("MFRGGZDF"));
+  EXPECT_EQ("abcdef", fromBase32("MFRGGZDFMY"));
+
+  // Invalid encoded strings must be reported as user errors.
+  EXPECT_THROW(fromBase32("1="), VeloxUserError);
+  EXPECT_THROW(fromBase32("M1======"), VeloxUserError);
+
+  VELOX_ASSERT_THROW(
+      fromBase32("J1======"),
+      "decode() - invalid input string: invalid characters");
+}
+
 TEST_F(BinaryFunctionsTest, fromBigEndian32) {
   const auto fromBigEndian32 = [&](const std::optional& arg) {
     return evaluateOnce(