Skip to content

Commit

Permalink
Clean up Base64
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Mar 25, 2024
1 parent 228a225 commit 8bb2ab9
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 95 deletions.
86 changes: 26 additions & 60 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <folly/Portability.h>
#include <folly/container/Foreach.h>
#include <folly/io/Cursor.h>
#include <stdint.h>

namespace facebook::velox::encoding {

Expand All @@ -28,20 +27,23 @@ constexpr static int kBinaryBlockSize = 3;
// Size of the encoded block after encoding.
constexpr static int kEncodedBlockSize = 4;

constexpr const Base64::Charset kBase64Charset = {
// Encoding base to be used.
constexpr static int kBase = 64;

constexpr const Charset kBase64Charset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};
constexpr const Base64::Charset kBase64UrlCharset = {
constexpr const Charset kBase64UrlCharset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'};

constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = {
constexpr const ReverseIndex kBase64ReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
Expand All @@ -60,7 +62,7 @@ constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};
constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
constexpr const ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
Expand All @@ -80,13 +82,6 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

constexpr bool checkForwardIndex(
uint8_t idx,
const Base64::Charset& charset,
const Base64::ReverseIndex& table) {
return (table[static_cast<uint8_t>(charset[idx])] == idx) &&
(idx > 0 ? checkForwardIndex(idx - 1, charset, table) : true);
}
// Verify that for every entry in kBase64Charset, the corresponding entry
// in kBase64ReverseIndexTable is correct.
static_assert(
Expand All @@ -103,34 +98,17 @@ static_assert(
kBase64UrlCharset,
kBase64UrlReverseIndexTable),
"kBase64UrlCharset has incorrect entries");
// Similar to strchr(), but for null-terminated const strings.
// Another difference is that we do not consider "\0" to be present in the
// string.
// Returns true if "str" contains the character c.
constexpr bool constCharsetContains(
const Base64::Charset& charset,
uint8_t idx,
const char c) {
return idx < charset.size() &&
((charset[idx] == c) || constCharsetContains(charset, idx + 1, c));
}
constexpr bool checkReverseIndex(
uint8_t idx,
const Base64::Charset& charset,
const Base64::ReverseIndex& table) {
return (table[idx] == 255
? !constCharsetContains(charset, 0, static_cast<char>(idx))
: (charset[table[idx]] == idx)) &&
(idx > 0 ? checkReverseIndex(idx - 1, charset, table) : true);
}

// Verify that for every entry in kBase64ReverseIndexTable, the corresponding
// entry in kBase64Charset is correct.
static_assert(
checkReverseIndex(
sizeof(kBase64ReverseIndexTable) - 1,
kBase64Charset,
kBase,
kBase64ReverseIndexTable),
"kBase64ReverseIndexTable has incorrect entries.");

// Verify that for every entry in kBase64ReverseIndexTable, the corresponding
// entry in kBase64Charset is correct.
// We can't run this check as the URL version has two duplicate entries so that
Expand Down Expand Up @@ -217,13 +195,13 @@ template <class T>
*wp++ = charset[(curr >> 12) & 0x3f];
*wp++ = charset[(curr >> 6) & 0x3f];
if (include_pad) {
*wp = kBase64Pad;
*wp = kPadding;
}
} else {
*wp++ = charset[(curr >> 12) & 0x3f];
if (include_pad) {
*wp++ = kBase64Pad;
*wp = kBase64Pad;
*wp++ = kPadding;
*wp = kPadding;
}
}
}
Expand Down Expand Up @@ -315,18 +293,6 @@ void Base64::decode(const char* data, size_t size, char* output) {
Base64::decode(data, size, output, out_len);
}

uint8_t Base64::Base64ReverseLookup(
char p,
const Base64::ReverseIndex& reverse_lookup) {
auto curr = reverse_lookup[(uint8_t)p];
if (curr >= 0x40) {
throw Base64Exception(
"Base64::decode() - invalid input string: invalid characters");
}

return curr;
}

size_t
Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable);
Expand All @@ -338,12 +304,12 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
return 0;
}

// Check if the input data is padded
// Check if the input data is padded.
if (isPadded(data, size)) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (size % kEncodedBlockSize != 0) {
throw Base64Exception(
throw EncoderException(
"Base64::decode() - invalid input string: "
"string length is not a multiple of the encoded block size.");
}
Expand All @@ -352,7 +318,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
auto padding = countPadding(data, size);
size -= padding;

// Adjust the needed size for padding
// Adjust the needed size for padding.
return needed -
ceil((padding * kBinaryBlockSize) /
static_cast<double>(kEncodedBlockSize));
Expand All @@ -364,7 +330,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
// Adjust the needed size for extra bytes, if present
if (extra) {
if (extra == 1) {
throw Base64Exception(
throw EncoderException(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
Expand All @@ -386,7 +352,7 @@ size_t Base64::decodeImpl(

auto needed = calculateDecodedSize(src, src_len);
if (dst_len < needed) {
throw Base64Exception(
throw EncoderException(
"Base64::decode() - invalid output string: "
"output string is too small.");
}
Expand All @@ -396,10 +362,10 @@ size_t Base64::decodeImpl(
// Each character of the 4 encode 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8 bit bytes.
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12) |
(Base64ReverseLookup(src[2], reverse_lookup) << 6) |
Base64ReverseLookup(src[3], reverse_lookup);
uint32_t last = (baseReverseLookup(kBase,src[0], reverse_lookup) << 18) |
(baseReverseLookup(kBase,src[1], reverse_lookup) << 12) |
(baseReverseLookup(kBase,src[2], reverse_lookup) << 6) |
baseReverseLookup(kBase,src[3], reverse_lookup);
dst[0] = (last >> 16) & 0xff;
dst[1] = (last >> 8) & 0xff;
dst[2] = last & 0xff;
Expand All @@ -408,14 +374,14 @@ size_t Base64::decodeImpl(
// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(src_len >= 2);
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12);
uint32_t last = (baseReverseLookup(kBase,src[0], reverse_lookup) << 18) |
(baseReverseLookup(kBase,src[1], reverse_lookup) << 12);
dst[0] = (last >> 16) & 0xff;
if (src_len > 2) {
last |= Base64ReverseLookup(src[2], reverse_lookup) << 6;
last |= baseReverseLookup(kBase,src[2], reverse_lookup) << 6;
dst[1] = (last >> 8) & 0xff;
if (src_len > 3) {
last |= Base64ReverseLookup(src[3], reverse_lookup);
last |= baseReverseLookup(kBase,src[3], reverse_lookup);
dst[2] = last & 0xff;
}
}
Expand Down
35 changes: 2 additions & 33 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,17 @@

#include <folly/Range.h>
#include <folly/io/IOBuf.h>
#include "velox/common/encode/EncoderUtils.h"

namespace facebook::velox::encoding {

class Base64Exception : public std::exception {
public:
explicit Base64Exception(const char* msg) : msg_(msg) {}
const char* what() const noexcept override {
return msg_;
}

protected:
const char* msg_;
};

class Base64 {
public:
using Charset = std::array<char, 64>;
using ReverseIndex = std::array<uint8_t, 256>;

static std::string encode(const char* data, size_t len);
static std::string encode(folly::StringPiece text);
static std::string encode(const folly::IOBuf* text);

/// Returns encoded size for the input of the specified size.
// Returns encoded size for the input of the specified size.
static size_t calculateEncodedSize(size_t size, bool withPadding = true);

/// Encodes the specified number of characters from the 'data' and writes the
Expand Down Expand Up @@ -91,25 +78,7 @@ class Base64 {
static void
decodeUrl(const char* src, size_t src_len, char* dst, size_t dst_len);

constexpr static char kBase64Pad = '=';

private:
static inline bool isPadded(const char* data, size_t len) {
return (len > 0 && data[len - 1] == kBase64Pad) ? true : false;
}

static inline size_t countPadding(const char* src, size_t len) {
size_t numPadding{0};
while (len > 0 && src[len - 1] == kBase64Pad) {
numPadding++;
len--;
}

return numPadding;
}

static uint8_t Base64ReverseLookup(char p, const ReverseIndex& table);

template <class T>
static std::string
encodeImpl(const T& data, const Charset& charset, bool include_pad);
Expand Down
2 changes: 1 addition & 1 deletion velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ TEST_F(Base64Test, calculateDecodedSizeProperSize) {
encoded_size = 21;
EXPECT_THROW(
Base64::calculateDecodedSize("SGVsbG8sIFdvcmxkIQ==", encoded_size),
facebook::velox::encoding::Base64Exception);
facebook::velox::encoding::EncoderException);

encoded_size = 32;
EXPECT_EQ(
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ struct FromBase64Function {
encoding::Base64::calculateDecodedSize(input.data(), inputSize));
encoding::Base64::decode(
input.data(), inputSize, result.data(), result.size());
} catch (const encoding::Base64Exception& e) {
} catch (const encoding::EncoderException& e) {
VELOX_USER_FAIL(e.what());
}
}
Expand Down

0 comments on commit 8bb2ab9

Please sign in to comment.