Skip to content

Commit

Permalink
Rewrite base64 to match the latest coding guidelines
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Jul 9, 2024
1 parent 77589a9 commit 72f2e1e
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 97 deletions.
137 changes: 68 additions & 69 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <folly/io/Cursor.h>
#include <stdint.h>

#include "velox/common/base/Exceptions.h"

namespace facebook::velox::encoding {

// Constants defining the size in bytes of binary and encoded blocks for Base64
Expand All @@ -29,19 +31,22 @@ constexpr static int kBinaryBlockByteSize = 3;
// Size of an encoded block in bytes (4 characters, each carrying 6 bits = 24 bits of payload)
constexpr static int kEncodedBlockByteSize = 4;

// Character sets for Base64 and Base64 URL encoding
// Standard Base64 alphabet. The character stored at position i is the one
// emitted for the 6-bit value i (the encoder indexes this table with
// `(curr >> shift) & 0x3f`). Positions 62 and 63 hold '+' and '/'.
constexpr const Base64::Charset kBase64Charset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/'};

// URL-safe Base64 alphabet. Same layout as kBase64Charset for positions 0-61;
// positions 62 and 63 hold '-' and '_' instead of '+' and '/'.
constexpr const Base64::Charset kBase64UrlCharset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_'};

// Reverse lookup tables for decoding
constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
Expand All @@ -61,6 +66,7 @@ constexpr const Base64::ReverseIndex kBase64ReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
Expand All @@ -81,76 +87,64 @@ constexpr const Base64::ReverseIndex kBase64UrlReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

// Returns true when, for every position i in [0, idx], looking up charset[i]
// in the reverse table yields i again. Evaluated at compile time by the
// static_asserts that validate the charset / reverse-index pairs.
constexpr bool checkForwardIndex(
    uint8_t idx,
    const Base64::Charset& charset,
    const Base64::ReverseIndex& table) {
  // Walk from the highest requested position down to 0 and stop at the first
  // mismatch.
  for (int pos = idx; pos >= 0; --pos) {
    const auto encodedChar = static_cast<uint8_t>(charset[pos]);
    if (table[encodedChar] != pos) {
      return false;
    }
  }
  return true;
}
// Verify that for every entry in kBase64Charset, the corresponding entry
// in kBase64ReverseIndexTable is correct.
static_assert(
checkForwardIndex(
Base64::checkForwardIndex(
sizeof(kBase64Charset) - 1,
kBase64Charset,
kBase64ReverseIndexTable),
"kBase64Charset has incorrect entries");

// Verify that for every entry in kBase64UrlCharset, the corresponding entry
// in kBase64UrlReverseIndexTable is correct.
static_assert(
checkForwardIndex(
Base64::checkForwardIndex(
sizeof(kBase64UrlCharset) - 1,
kBase64UrlCharset,
kBase64UrlReverseIndexTable),
"kBase64UrlCharset has incorrect entries");
// Similar to strchr(), but for null-terminated const strings.
// Another difference is that we do not consider "\0" to be present in the
// string.
// Returns true if "str" contains the character c.
constexpr bool constCharsetContains(

// static
const bool Base64::findCharacterInCharSet(
const Base64::Charset& charset,
uint8_t idx,
const char c) {
return idx < charset.size() &&
((charset[idx] == c) || constCharsetContains(charset, idx + 1, c));
}
constexpr bool checkReverseIndex(
uint8_t idx,
const Base64::Charset& charset,
const Base64::ReverseIndex& table) {
return (table[idx] == 255
? !constCharsetContains(charset, 0, static_cast<char>(idx))
: (charset[table[idx]] == idx)) &&
(idx > 0 ? checkReverseIndex(idx - 1, charset, table) : true);
for (; idx < charset.size(); ++idx) {
if (charset[idx] == c) {
return true;
}
}
return false;
}

// Verify that for every entry in kBase64ReverseIndexTable, the corresponding
// entry in kBase64Charset is correct.
static_assert(
checkReverseIndex(
Base64::checkReverseIndex(
sizeof(kBase64ReverseIndexTable) - 1,
kBase64Charset,
kBase64ReverseIndexTable),
"kBase64ReverseIndexTable has incorrect entries.");

// Verify that for every entry in kBase64UrlReverseIndexTable, the
// corresponding entry in kBase64UrlCharset is correct.
// We can't run this check as the URL version has two duplicate entries so that
// the url decoder can handle url encodings and default encodings
// static_assert(
// checkReverseIndex(
// sizeof(kBase64UrlReverseIndexTable) - 1,
// kBase64UrlCharset,
// kBase64UrlReverseIndexTable),
// "kBase64UrlReverseIndexTable has incorrect entries.");
static_assert(
Base64::checkReverseIndex(
sizeof(kBase64UrlReverseIndexTable) - 1,
kBase64UrlCharset,
kBase64UrlReverseIndexTable),
"kBase64UrlReverseIndexTable has incorrect entries.");

// Implementation of Base64 encoding and decoding functions.
template <class T>
/* static */ std::string
Base64::encodeImpl(const T& data, const Charset& charset, bool include_pad) {
/* static */ std::string Base64::encodeImpl(
const T& data,
const Base64::Charset& charset,
bool include_pad) {
size_t outlen = calculateEncodedSize(data.size(), include_pad);

std::string out;
out.resize(outlen);

encodeImpl(data, charset, include_pad, out.data());
return out;
}
Expand Down Expand Up @@ -183,7 +177,7 @@ void Base64::encodeUrl(const char* data, size_t len, char* output) {
template <class T>
/* static */ void Base64::encodeImpl(
const T& data,
const Charset& charset,
const Base64::Charset& charset,
bool include_pad,
char* out) {
auto len = data.size();
Expand Down Expand Up @@ -218,22 +212,24 @@ template <class T>
*wp++ = charset[(curr >> 12) & 0x3f];
*wp++ = charset[(curr >> 6) & 0x3f];
if (include_pad) {
*wp = kBase64Pad;
*wp = kPadding;
}
} else {
*wp++ = charset[(curr >> 12) & 0x3f];
if (include_pad) {
*wp++ = kBase64Pad;
*wp = kBase64Pad;
*wp++ = kPadding;
*wp = kPadding;
}
}
}
}

// static
// Encodes `text` using the standard charset with padding enabled and returns
// the encoded string.
std::string Base64::encode(folly::StringPiece text) {
  constexpr bool includePad = true;
  return encodeImpl(text, kBase64Charset, includePad);
}

// static
std::string Base64::encode(const char* data, size_t len) {
return encode(folly::StringPiece(data, len));
}
Expand Down Expand Up @@ -284,24 +280,19 @@ class IOBufWrapper {

} // namespace

// static
// Encodes the contents of an IOBuf using the standard charset with padding
// enabled. The buffer is accessed through IOBufWrapper.
std::string Base64::encode(const folly::IOBuf* data) {
  const IOBufWrapper wrapped(data);
  constexpr bool includePad = true;
  return encodeImpl(wrapped, kBase64Charset, includePad);
}

// Appends the padded, standard-charset encoding of `text` to the end of
// `out`, leaving the existing contents of `out` untouched.
void Base64::encodeAppend(folly::StringPiece text, std::string& out) {
  constexpr bool withPadding = true;

  // Remember where the existing contents end, then grow the string by the
  // exact encoded length so the encoder can write directly into the new tail.
  const size_t prefixLen = out.size();
  const size_t encodedLen = calculateEncodedSize(text.size(), withPadding);
  out.resize(prefixLen + encodedLen);

  char* const tail = out.data() + prefixLen;
  encodeImpl(text, kBase64Charset, withPadding, tail);
}

// static
// Convenience overload: forwards to the (pointer, length) pair overload of
// decode() and returns the string that overload fills in.
// NOTE(review): the pair overload takes an int32_t length, so encoded.size()
// is narrowed in the conversion — confirm inputs cannot exceed INT32_MAX.
std::string Base64::decode(folly::StringPiece encoded) {
  std::string decoded;
  const auto payload = std::make_pair(encoded.data(), encoded.size());
  Base64::decode(payload, decoded);
  return decoded;
}

// static
void Base64::decode(
const std::pair<const char*, int32_t>& payload,
std::string& output) {
Expand All @@ -316,18 +307,18 @@ void Base64::decode(const char* data, size_t size, char* output) {
Base64::decode(data, size, output, out_len);
}

uint8_t Base64::Base64ReverseLookup(
// static
uint8_t Base64::base64ReverseLookup(
char p,
const Base64::ReverseIndex& reverse_lookup) {
auto curr = reverse_lookup[(uint8_t)p];
const Base64::ReverseIndex& reverseIndex) {
auto curr = reverseIndex[(uint8_t)p];
if (curr >= 0x40) {
throw Base64Exception(
"Base64::decode() - invalid input string: invalid characters");
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
}

return curr;
}

// static
size_t
Base64::decode(const char* src, size_t src_len, char* dst, size_t dst_len) {
return decodeImpl(src, src_len, dst, dst_len, kBase64ReverseIndexTable);
Expand All @@ -344,13 +335,13 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (size % kEncodedBlockByteSize != 0) {
throw Base64Exception(
VELOX_USER_FAIL(
"Base64::decode() - invalid input string: "
"string length is not a multiple of 4.");
}

auto needed = (size * kBinaryBlockByteSize) / kEncodedBlockByteSize;
auto padding = countPadding(data, size);
auto padding = numPadding(data, size);
size -= padding;

// Adjust the needed size by deducting the bytes corresponding to the
Expand All @@ -366,7 +357,7 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
// Adjust the needed size for extra bytes, if present
if (extra) {
if (extra == 1) {
throw Base64Exception(
VELOX_USER_FAIL(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
}
Expand All @@ -376,19 +367,20 @@ size_t Base64::calculateDecodedSize(const char* data, size_t& size) {
return needed;
}

// static
size_t Base64::decodeImpl(
const char* src,
size_t src_len,
char* dst,
size_t dst_len,
const ReverseIndex& reverse_lookup) {
const Base64::ReverseIndex& reverseIndex) {
if (!src_len) {
return 0;
}

auto needed = calculateDecodedSize(src, src_len);
if (dst_len < needed) {
throw Base64Exception(
VELOX_USER_FAIL(
"Base64::decode() - invalid output string: "
"output string is too small.");
}
Expand All @@ -398,10 +390,10 @@ size_t Base64::decodeImpl(
// Each character of the 4 encode 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8 bit bytes.
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12) |
(Base64ReverseLookup(src[2], reverse_lookup) << 6) |
Base64ReverseLookup(src[3], reverse_lookup);
uint32_t last = (Base64::base64ReverseLookup(src[0], reverseIndex) << 18) |
(Base64::base64ReverseLookup(src[1], reverseIndex) << 12) |
(Base64::base64ReverseLookup(src[2], reverseIndex) << 6) |
Base64::base64ReverseLookup(src[3], reverseIndex);
dst[0] = (last >> 16) & 0xff;
dst[1] = (last >> 8) & 0xff;
dst[2] = last & 0xff;
Expand All @@ -410,33 +402,37 @@ size_t Base64::decodeImpl(
// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(src_len >= 2);
uint32_t last = (Base64ReverseLookup(src[0], reverse_lookup) << 18) |
(Base64ReverseLookup(src[1], reverse_lookup) << 12);
uint32_t last = (Base64::base64ReverseLookup(src[0], reverseIndex) << 18) |
(Base64::base64ReverseLookup(src[1], reverseIndex) << 12);
dst[0] = (last >> 16) & 0xff;
if (src_len > 2) {
last |= Base64ReverseLookup(src[2], reverse_lookup) << 6;
last |= Base64::base64ReverseLookup(src[2], reverseIndex) << 6;
dst[1] = (last >> 8) & 0xff;
if (src_len > 3) {
last |= Base64ReverseLookup(src[3], reverse_lookup);
last |= Base64::base64ReverseLookup(src[3], reverseIndex);
dst[2] = last & 0xff;
}
}

return needed;
}

// static
// Encodes `text` using the URL-safe charset with padding disabled and returns
// the encoded string.
std::string Base64::encodeUrl(folly::StringPiece text) {
  constexpr bool includePad = false;
  return encodeImpl(text, kBase64UrlCharset, includePad);
}

// static
std::string Base64::encodeUrl(const char* data, size_t len) {
return encodeUrl(folly::StringPiece(data, len));
}

// static
// Encodes the contents of an IOBuf using the URL-safe charset with padding
// disabled. The buffer is accessed through IOBufWrapper.
std::string Base64::encodeUrl(const folly::IOBuf* data) {
  const IOBufWrapper wrapped(data);
  constexpr bool includePad = false;
  return encodeImpl(wrapped, kBase64UrlCharset, includePad);
}

// static
void Base64::decodeUrl(
const char* src,
size_t src_len,
Expand All @@ -445,12 +441,14 @@ void Base64::decodeUrl(
decodeImpl(src, src_len, dst, dst_len, kBase64UrlReverseIndexTable);
}

// static
// Convenience overload: forwards to the (pointer, length) pair overload of
// decodeUrl() and returns the string that overload fills in.
// NOTE(review): the pair overload takes an int32_t length, so encoded.size()
// is narrowed in the conversion — confirm inputs cannot exceed INT32_MAX.
std::string Base64::decodeUrl(folly::StringPiece encoded) {
  std::string decoded;
  const auto payload = std::make_pair(encoded.data(), encoded.size());
  Base64::decodeUrl(payload, decoded);
  return decoded;
}

// static
void Base64::decodeUrl(
const std::pair<const char*, int32_t>& payload,
std::string& output) {
Expand All @@ -464,4 +462,5 @@ void Base64::decodeUrl(
kBase64UrlReverseIndexTable);
output.resize(out_len);
}

} // namespace facebook::velox::encoding
Loading

0 comments on commit 72f2e1e

Please sign in to comment.