Skip to content

Commit

Permalink
Fix padding issue in Base64
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Mar 11, 2024
1 parent a503f4f commit 9d33332
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 74 deletions.
98 changes: 43 additions & 55 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

constexpr bool checkForwardIndex(
uint8_t idx,
const Base64::Charset& charset,
const Base64::ReverseIndex& table) {
return (table[static_cast<uint8_t>(charset[idx])] == idx) &&
(idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true);
}
// Verify that for every entry in kBase64Charset, the corresponding entry
// in kBase64ReverseIndexTable is correct.
static_assert(
Expand Down Expand Up @@ -324,67 +317,64 @@ uint8_t Base64::Base64ReverseLookup(

size_t
Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable, true);
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable);
}

// static
size_t
Base64::calculateDecodedSize(const char* data, size_t& size, bool withPadding) {
size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
if (size == 0) {
return 0;
}

auto needed = (size / 4) * 3;
if (withPadding) {
// If the pad characters are included then the source string must be a
// multiple of 4 and we can query the end of the string to see how much
// padding exists.
if (size % 4 != 0) {
throw Base64Exception(
// Check if the input data is padded
if (isPadded(data, size)) {
/// If padded, ensure that the string length is a multiple of the encoded
/// block size
if (size % kEncodedBlockSize != 0) {
throw EncoderException(
"Base64::decode() - invalid input string: "
"string length is not multiple of 4.");
"string length is not a multiple of the encoded block size.");
}

auto needed = (size * kBinaryBlockSize) / kEncodedBlockSize;
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;
}

// If padding doesn't exist we need to calculate it from the size - if the
// size % 4 is 0 then we have an even multiple 3 byte chunks in the result
// if it is 2 then we need 1 more byte in the output. If it is 3 then we
// need 2 more bytes in the output. It should never be 1.
auto extra = size % 4;
if (extra) {
if (extra == 1) {
throw Base64Exception(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
// Adjust the needed size for padding
return needed -
ceil((padding * kBinaryBlockSize) /
static_cast<double>(kEncodedBlockSize));
} else {
// If not padded, Calculate extra bytes, if any
auto extra = size % kEncodedBlockSize;
auto needed = (size / kEncodedBlockSize) * kBinaryBlockSize;

// Adjust the needed size for extra bytes, if present
if (extra) {
if (extra == 1) {
throw EncoderException(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
needed += (extra * kBinaryBlockSize) / kEncodedBlockSize;
}
return needed + extra - 1;
}

// Just because we don't need the pad, doesn't mean it is not there. The
// URL decoder should be able to handle the original encoding.
auto padding = countPadding(data, size);
size -= padding;
return needed - padding;
return needed;
}
}

size_t Base64::decodeImpl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
const Base64::ReverseIndex& reverse_lookup,
bool include_pad) {
const ReverseIndex& reverse_lookup) {
if (!src_len) {
return 0;
}

auto needed = calculateDecodedSize(src, src_len, include_pad);
auto needed = calculateDecodedSize(src, src_len);
if (dst_len < needed) {
throw Base64Exception(
throw EncoderException(
"Base64::decode() - invalid output string: "
"output string is too small.");
}
Expand All @@ -394,10 +384,10 @@ size_t Base64::decodeImpl(
// Each character of the 4 encode 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8 bit bytes.
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12) |
(Base64ReverseLookup(src[2], reverse_lookup) << 6) |
Base64ReverseLookup(src[3], reverse_lookup);
uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) |
(baseReverseLookup(kBase, src[1], reverse_lookup) << 12) |
(baseReverseLookup(kBase, src[2], reverse_lookup) << 6) |
baseReverseLookup(kBase, src[3], reverse_lookup);
dst[0] = (last >> 16) & 0xff;
dst[1] = (last >> 8) & 0xff;
dst[2] = last & 0xff;
Expand All @@ -406,14 +396,14 @@ size_t Base64::decodeImpl(
// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(src_len >= 2);
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12);
uint32_t last = (baseReverseLookup(kBase, src[0], reverse_lookup) << 18) |
(baseReverseLookup(kBase, src[1], reverse_lookup) << 12);
dst[0] = (last >> 16) & 0xff;
if (src_len > 2) {
last |= Base64ReverseLookup(src[2], reverse_lookup) << 6;
last |= baseReverseLookup(kBase, src[2], reverse_lookup) << 6;
dst[1] = (last >> 8) & 0xff;
if (src_len > 3) {
last |= Base64ReverseLookup(src[3], reverse_lookup);
last |= baseReverseLookup(kBase, src[3], reverse_lookup);
dst[2] = last & 0xff;
}
}
Expand All @@ -437,9 +427,8 @@ void Base64::decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool hasPad) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable, hasPad);
size_t dst_len) {
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable);
}

std::string Base64::decodeUrl(folly::StringPiece encoded) {
Expand All @@ -458,8 +447,7 @@ void Base64::decodeUrl(
payload.second,
&output[0],
out_len,
kBase64UrlReverseIndexTable,
false);
kBase64UrlReverseIndexTable);
output.resize(out_len);
}
} // namespace facebook::velox::encoding
28 changes: 17 additions & 11 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <folly/Range.h>
#include <folly/io/IOBuf.h>
#include "velox/common/encode/EncoderUtils.h"

namespace facebook::velox::encoding {

Expand Down Expand Up @@ -59,8 +60,7 @@ class Base64 {

/// Returns decoded size for the specified input. Adjusts the 'size' to
/// subtract the length of the padding, if exists.
static size_t
calculateDecodedSize(const char* data, size_t& size, bool withPadding = true);
static size_t calculateDecodedSize(const char* data, size_t& size);

/// Decodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -69,7 +69,7 @@ class Base64 {

static void decode(
const std::pair<const char*, int32_t>& payload,
std::string& outp);
std::string& output);

/// Encodes the specified number of characters from the 'data' and writes the
/// result to the 'output'. The output must have enough space, e.g. as
Expand All @@ -89,12 +89,8 @@ class Base64 {
static size_t
decode(const char* src, size_t src_len, char* dst, size_t dst_len);

static void decodeUrl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
bool pad);
static void
decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len);

constexpr static char kBase64Pad = '=';

Expand Down Expand Up @@ -122,8 +118,18 @@ class Base64 {
size_t src_len,
char* dst,
size_t dst_len,
const ReverseIndex& table,
bool include_pad);
const ReverseIndex& table);

public:
// Encoding base to be used.
constexpr static int kBase = 64;

private:
// Size of the binary block before encoding.
constexpr static int kBinaryBlockSize = 3;

// Size of the encoded block after encoding.
constexpr static int kEncodedBlockSize = 4;
};

} // namespace facebook::velox::encoding
13 changes: 5 additions & 8 deletions velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,9 @@ struct FromBase64Function {
auto inputSize = input.size();
result.resize(
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decode(input.data(), input.size(), result.data());
} catch (const encoding::Base64Exception& e) {
encoding::Base64::decode(
input.data(), inputSize, result.data(), result.size());
} catch (const encoding::EncoderException& e) {
VELOX_USER_FAIL(e.what());
}
}
Expand All @@ -306,15 +307,11 @@ struct FromBase64UrlFunction {
FOLLY_ALWAYS_INLINE void call(
out_type<Varbinary>& result,
const arg_type<Varchar>& input) {
auto inputData = input.data();
auto inputSize = input.size();
bool hasPad =
inputSize > 0 && (*(input.end() - 1) == encoding::Base64::kBase64Pad);
result.resize(
encoding::Base64::calculateDecodedSize(inputData, inputSize, hasPad));
hasPad = false; // calculateDecodedSize() updated inputSize to exclude pad.
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decodeUrl(
inputData, inputSize, result.data(), result.size(), hasPad);
input.data(), inputSize, result.data(), result.size());
}
};

Expand Down
5 changes: 5 additions & 0 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,11 +424,16 @@ TEST_F(BinaryFunctionsTest, fromBase64) {
EXPECT_EQ(std::nullopt, fromBase64(std::nullopt));
EXPECT_EQ("", fromBase64(""));
EXPECT_EQ("a", fromBase64("YQ=="));
EXPECT_EQ("ab", fromBase64("YWI="));
EXPECT_EQ("abc", fromBase64("YWJj"));
EXPECT_EQ("hello world", fromBase64("aGVsbG8gd29ybGQ="));
EXPECT_EQ(
"Hello World from Velox!",
fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE="));
// Check encoded strings without padding
EXPECT_EQ("a", fromBase64("YQ"));
EXPECT_EQ("ab", fromBase64("YWI"));
EXPECT_EQ("abcd", fromBase64("YWJjZA"));

EXPECT_THROW(fromBase64("YQ="), VeloxUserError);
EXPECT_THROW(fromBase64("YQ==="), VeloxUserError);
Expand Down

0 comments on commit 9d33332

Please sign in to comment.