Skip to content

Commit

Permalink
Fix from_base64 Presto function for inputs without padding
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Apr 22, 2024
1 parent d44a23e commit 8e1a45a
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 58 deletions.
66 changes: 32 additions & 34 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@

namespace facebook::velox::encoding {

// Constants defining the size of binary and encoded blocks for Base64 encoding.
constexpr static int kBinaryBlockSize = 3; // 3 bytes of binary = 24 bits
constexpr static int kEncodedBlockSize = 4; // 4 bytes of encoded = 24 bits

constexpr const Base64::Charset kBase64Charset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
Expand Down Expand Up @@ -298,10 +302,9 @@ std::string Base64::decode(folly::StringPiece encoded) {
void Base64::decode(
const std::pair<const char*, int32_t>& payload,
std::string& output) {
size_t out_len = payload.second / 4 * 3;
output.resize(out_len, '\0');
out_len = Base64::decode(payload.first, payload.second, &output[0], out_len);
output.resize(out_len);
size_t inputSize = payload.second;
output.resize(calculateDecodedSize(payload.first, inputSize));
decode(payload.first, inputSize, output.data(), output.size());
}

// static
Expand All @@ -324,65 +327,62 @@ uint8_t Base64::Base64ReverseLookup(

size_t
Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true);
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable);
}

// static
size_t
Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) {
size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
if (size == 0) {
return 0;
}

auto needed = (size / 4) * 3;
if (withPadding) {
// If the pad characters are included then the source string must be a
// multiple of 4 and we can query the end of the string to see how much
// padding exists.
if (size % 4 != 0) {
// Check if the input data is padded
if (isPadded(data, size)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (size % kEncodedBlockSize != 0) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length is not multiple of 4.");
"string length is not a multiple of 4.");
}

auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize;
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;

// Adjust the needed size for padding
return needed -
ceil((padding * kBinaryBlockSize) /
static_cast<double>(kEncodedBlockSize));
}
// If not padded, Calculate extra bytes, if any
auto extra = size % kEncodedBlockSize;
auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize;

// If padding doesn't exist we need to calculate it from the size - if the
// size % 4 is 0 then we have an even multiple 3 byte chunks in the result
// if it is 2 then we need 1 more byte in the output. If it is 3 then we
// need 2 more bytes in the output. It should never be 1.
auto extra = size % 4;
// Adjust the needed size for extra bytes, if present
if (extra) {
if (extra == 1) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
return needed + extra - 1;
needed += (extra * kBinaryBlockSize) / kEncodedBlockSize;
}

// Just because we don't need the pad, doesn't mean it is not there. The
// URL decoder should be able to handle the original encoding.
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;
return needed;
}

size_t Base64::decodeImpl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
const Base64::ReverseIndex& reverse_lookup,
bool include_pad) {
const ReverseIndex& reverse_lookup) {
if (!src_len) {
return 0;
}

auto needed = calculateDecodedSize(src, src_len, include_pad);
auto needed = calculateDecodedSize(src, src_len);
if (dst_len < needed) {
throw Base64Exception(
"Base64::decode() - invalid output string: "
Expand Down Expand Up @@ -437,9 +437,8 @@ void Base64::decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool hasPad) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad);
size_t dst_len) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable);
}

std::string Base64::decodeUrl(folly::StringPiece encoded) {
Expand All @@ -458,8 +457,7 @@ void Base64::decodeUrl(
payload.second,
&output[0],
out_len,
kBase64UrlReverseIndexTable,
false);
kBase64UrlReverseIndexTable);
output.resize(out_len);
}
} // namespace facebook::velox::encoding
33 changes: 18 additions & 15 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ class Base64 {

static std::string decode(folly::StringPiece encoded);

/// Returns decoded size for the specified input. Adjusts the 'size' to
/// subtract the length of the padding, if exists.
static size_t
calculateDecodedSize(const char* data, size_t& size, bool withPadding = true);
/// Returns the actual size of the decoded data. Will also remove the padding
/// length from the input data size.
static size_t calculateDecodedSize(const char* data, size_t& size);

/// Decodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -69,7 +68,7 @@ class Base64 {

static void decode(
const std::pair<const char*, int32_t>& payload,
std::string& outp);
std::string& output);

/// Encodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -89,19 +88,24 @@ class Base64 {
static size_t
decode(const char* src, size_t src_len, char* dst, size_t dst_len);

static void decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool pad);
static void
decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len);

constexpr static char kBase64Pad = '=';

private:
static inline bool isPadded(const char* data, size_t len) {
return (len > 0 && data[len - 1] == kBase64Pad);
}

static inline size_t countPadding(const char* src, size_t len) {
DCHECK_GE(len, 2);
return src[len - 1] != kBase64Pad ? 0 : src[len - 2] != kBase64Pad ? 1 : 2;
size_t numPadding{0};
while (len > 0 && src[len - 1] == kBase64Pad) {
numPadding++;
len--;
}

return numPadding;
}

static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table);
Expand All @@ -122,8 +126,7 @@ class Base64 {
size_t src_len,
char* dst,
size_t dst_len,
const ReverseIndex& table,
bool include_pad);
const ReverseIndex& table);
};

} // namespace facebook::velox::encoding
4 changes: 4 additions & 0 deletions velox/common/encode/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
endif()

add_library(velox_encode Base64.cpp)
target_link_libraries(velox_encode PUBLIC Folly::folly)
89 changes: 89 additions & 0 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/encode/Base64.h"
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"

namespace facebook::velox::encoding {
class Base64Test : public ::testing::Test {};

TEST_F(Base64Test, fromBase64) {
EXPECT_EQ(
"Hello, World!",
Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ==")));
EXPECT_EQ(
"Base64 encoding is fun.",
Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=")));
EXPECT_EQ(
"Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ=")));
EXPECT_EQ(
"1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA==")));

// Check encoded strings without padding
EXPECT_EQ(
"Hello, World!",
Base64::decode(folly::StringPiece("SGVsbG8sIFdvcmxkIQ")));
EXPECT_EQ(
"Base64 encoding is fun.",
Base64::decode(folly::StringPiece("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4")));
EXPECT_EQ(
"Simple text", Base64::decode(folly::StringPiece("U2ltcGxlIHRleHQ")));
EXPECT_EQ("1234567890", Base64::decode(folly::StringPiece("MTIzNDU2Nzg5MA")));
}

TEST_F(Base64Test, calculateDecodedSizeProperSize) {
size_t encoded_size{0};

encoded_size = 20;
EXPECT_EQ(
13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size));
EXPECT_EQ(18, encoded_size);

encoded_size = 18;
EXPECT_EQ(
13, Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ", encoded_size));
EXPECT_EQ(18, encoded_size);

encoded_size = 21;
EXPECT_THROW(
Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size),
facebook::velox::encoding::Base64Exception);

encoded_size = 32;
EXPECT_EQ(
23,
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size));
EXPECT_EQ(31, encoded_size);

encoded_size = 31;
EXPECT_EQ(
23,
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size));
EXPECT_EQ(31, encoded_size);

encoded_size = 16;
EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size));
EXPECT_EQ(14, encoded_size);

encoded_size = 14;
EXPECT_EQ(10, Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size));
EXPECT_EQ(14, encoded_size);
}

} // namespace facebook::velox::encoding
23 changes: 23 additions & 0 deletions velox/common/encode/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

add_executable(velox_common_encode_test Base64Test.cpp)
add_test(velox_common_encode_test velox_common_encode_test)
target_link_libraries(
velox_common_encode_test
PUBLIC Folly::folly
PRIVATE velox_encode
velox_exception
gtest
gtest_main)
13 changes: 4 additions & 9 deletions velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,15 @@ struct ToBase64Function {
template <typename T>
struct FromBase64Function {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varbinary>& result,
const arg_type<Varchar>& input) {
try {
auto inputSize = input.size();
result.resize(
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decode(input.data(), input.size(), result.data());
encoding::Base64::decode(
input.data(), inputSize, result.data(), result.size());
} catch (const encoding::Base64Exception& e) {
VELOX_USER_FAIL(e.what());
}
Expand All @@ -302,19 +302,14 @@ struct FromBase64Function {
template <typename T>
struct FromBase64UrlFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varbinary>& result,
const arg_type<Varchar>& input) {
auto inputData = input.data();
auto inputSize = input.size();
bool hasPad =
inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad);
result.resize(
encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad));
hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad.
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decodeUrl(
inputData, inputSize, result.data(), result.size(), hasPad);
input.data(), inputSize, result.data(), result.size());
}
};

Expand Down
6 changes: 6 additions & 0 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ TEST_F(BinaryFunctionsTest, fromBase64) {
EXPECT_EQ(std::nullopt, fromBase64(std::nullopt));
EXPECT_EQ("", fromBase64(""));
EXPECT_EQ("a", fromBase64("YQ=="));
EXPECT_EQ("ab", fromBase64("YWI="));
EXPECT_EQ("abc", fromBase64("YWJj"));
EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ="));
EXPECT_EQ(
Expand All @@ -433,6 +434,11 @@ TEST_F(BinaryFunctionsTest, fromBase64) {

EXPECT_THROW(fromBase64("YQ="), VeloxUserError);
EXPECT_THROW(fromBase64("YQ==="), VeloxUserError);

// Check encoded strings without padding
EXPECT_EQ("a", fromBase64("YQ"));
EXPECT_EQ("ab", fromBase64("YWI"));
EXPECT_EQ("abcd", fromBase64("YWJjZA"));
}

TEST_F(BinaryFunctionsTest, fromBase64Url) {
Expand Down

0 comments on commit 8e1a45a

Please sign in to comment.