Skip to content

Commit

Permalink
Fix padding issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Mar 11, 2024
1 parent 02ca9b0 commit 8543883
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 60 deletions.
71 changes: 33 additions & 38 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,65 +324,62 @@ uint8_t Base64::Base64ReverseLookup(

size_t
Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true);
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable);
}

// static
size_t
Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) {
size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
if (size == 0) {
return 0;
}

auto needed = (size / 4) * 3;
if (withPadding) {
// If the pad characters are included then the source string must be a
// multiple of 4 and we can query the end of the string to see how much
// padding exists.
if (size % 4 != 0) {
// Check if the input data is padded
if (isPadded(data, size)) {
/// If padded, ensure that the string length is a multiple of the encoded
/// block size
if (size % kEncodedBlockSize != 0) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length is not multiple of 4.");
"string length is not a multiple of the encoded block size.");
}

auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize;
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;
}

// If padding doesn't exist we need to calculate it from the size - if the
// size % 4 is 0 then we have an even multiple 3 byte chunks in the result
// if it is 2 then we need 1 more byte in the output. If it is 3 then we
// need 2 more bytes in the output. It should never be 1.
auto extra = size % 4;
if (extra) {
if (extra == 1) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
// Adjust the needed size for padding
return needed -
ceil((padding * kBinaryBlockSize) /
static_cast<double>(kEncodedBlockSize));
} else {
// If not padded, Calculate extra bytes, if any
auto extra = size % kEncodedBlockSize;
auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize;

// Adjust the needed size for extra bytes, if present
if (extra) {
if (extra == 1) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
needed += (extra * kBinaryBlockSize) / kEncodedBlockSize;
}
return needed + extra - 1;
}

// Just because we don't need the pad, doesn't mean it is not there. The
// URL decoder should be able to handle the original encoding.
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;
return needed;
}
}

size_t Base64::decodeImpl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
const Base64::ReverseIndex& reverse_lookup,
bool include_pad) {
const ReverseIndex& reverse_lookup) {
if (!src_len) {
return 0;
}

auto needed = calculateDecodedSize(src, src_len, include_pad);
auto needed = calculateDecodedSize(src, src_len);
if (dst_len < needed) {
throw Base64Exception(
"Base64::decode() - invalid output string: "
Expand Down Expand Up @@ -437,9 +434,8 @@ void Base64::decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool hasPad) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad);
size_t dst_len) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable);
}

std::string Base64::decodeUrl(folly::StringPiece encoded) {
Expand All @@ -458,8 +454,7 @@ void Base64::decodeUrl(
payload.second,
&output[0],
out_len,
kBase64UrlReverseIndexTable,
false);
kBase64UrlReverseIndexTable);
output.resize(out_len);
}
} // namespace facebook::velox::encoding
43 changes: 30 additions & 13 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,7 @@ class Base64 {

/// Returns decoded size for the specified input. Adjusts the 'size' to
/// subtract the length of the padding, if exists.
static size_t
calculateDecodedSize(const char* data, size_t& size, bool withPadding = true);
static size_t calculateDecodedSize(const char* data, size_t& size);

/// Decodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -69,7 +68,7 @@ class Base64 {

static void decode(
const std::pair<const char*, int32_t>& payload,
std::string& outp);
std::string& output);

/// Encodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -89,19 +88,24 @@ class Base64 {
static size_t
decode(const char* src, size_t src_len, char* dst, size_t dst_len);

static void decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool pad);
static void
decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len);

constexpr static char kBase64Pad = '=';

private:
static inline bool isPadded(const char* data, size_t len) {
return (len > 0 && data[len - 1] == kPadding) ? true : false;
}

static inline size_t countPadding(const char* src, size_t len) {
DCHECK_GE(len, 2);
return src[len - 1] != kBase64Pad ? 0 : src[len - 2] != kBase64Pad ? 1 : 2;
size_t padding_count = 0;
while (len > 0 && src[len - 1] == kPadding) {
padding_count++;
len--;
}

return padding_count;
}

static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table);
Expand All @@ -122,8 +126,21 @@ class Base64 {
size_t src_len,
char* dst,
size_t dst_len,
const ReverseIndex& table,
bool include_pad);
const ReverseIndex& table);

public:
// Encoding base to be used.
constexpr static int kBase = 64;

private:
// Size of the binary block before encoding.
constexpr static int kBinaryBlockSize = 3;

// Padding character used in encoding
constexpr static char kPadding = '=';

// Size of the encoded block after encoding.
constexpr static int kEncodedBlockSize = 4;
};

} // namespace facebook::velox::encoding
13 changes: 4 additions & 9 deletions velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,15 +284,15 @@ struct ToBase64Function {
template <typename T>
struct FromBase64Function {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varbinary>& result,
const arg_type<Varchar>& input) {
try {
auto inputSize = input.size();
result.resize(
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decode(input.data(), input.size(), result.data());
encoding::Base64::decode(
input.data(), inputSize, result.data(), result.size());
} catch (const encoding::Base64Exception& e) {
VELOX_USER_FAIL(e.what());
}
Expand All @@ -302,19 +302,14 @@ struct FromBase64Function {
template <typename T>
struct FromBase64UrlFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(
out_type<Varbinary>& result,
const arg_type<Varchar>& input) {
auto inputData = input.data();
auto inputSize = input.size();
bool hasPad =
inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad);
result.resize(
encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad));
hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad.
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decodeUrl(
inputData, inputSize, result.data(), result.size(), hasPad);
input.data(), inputSize, result.data(), result.size());
}
};

Expand Down
6 changes: 6 additions & 0 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ TEST_F(BinaryFunctionsTest, fromBase64) {
EXPECT_EQ(std::nullopt, fromBase64(std::nullopt));
EXPECT_EQ("", fromBase64(""));
EXPECT_EQ("a", fromBase64("YQ=="));
EXPECT_EQ("ab", fromBase64("YWI="));
EXPECT_EQ("abc", fromBase64("YWJj"));
EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ="));
EXPECT_EQ(
Expand All @@ -432,6 +433,11 @@ TEST_F(BinaryFunctionsTest, fromBase64) {

EXPECT_THROW(fromBase64("YQ="), VeloxUserError);
EXPECT_THROW(fromBase64("YQ==="), VeloxUserError);

// Check encoded strings without padding
EXPECT_EQ("a", fromBase64("YQ"));
EXPECT_EQ("ab", fromBase64("YWI"));
EXPECT_EQ("abcd", fromBase64("YWJjZA"));
}

TEST_F(BinaryFunctionsTest, fromBase64Url) {
Expand Down

0 comments on commit 8543883

Please sign in to comment.