Skip to content

Commit

Permalink
Convert Base64 as a non-throwing API
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Sep 23, 2024
1 parent 936300e commit 70adfa6
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 105 deletions.
135 changes: 86 additions & 49 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
#include <folly/io/Cursor.h>
#include <stdint.h>

#include "velox/common/base/Exceptions.h"

namespace facebook::velox::encoding {

// Constants defining the size in bytes of binary and encoded blocks for Base64
Expand Down Expand Up @@ -162,9 +160,9 @@ std::string Base64::encodeImpl(
const T& input,
const Base64::Charset& charset,
bool includePadding) {
size_t outputLength = calculateEncodedSize(input.size(), includePadding);
size_t outputSize = calculateEncodedSize(input.size(), includePadding);
std::string output;
output.resize(outputLength);
output.resize(outputSize);
encodeImpl(input, charset, includePadding, output.data());
return output;
}
Expand Down Expand Up @@ -313,26 +311,30 @@ void Base64::decode(std::string_view input, std::string& output) {
size_t inputSize{input.size()};
size_t decodedSize;

calculateDecodedSize(input, inputSize, decodedSize);
(void)calculateDecodedSize(input, inputSize, decodedSize);
output.resize(decodedSize);
decode(input.data(), inputSize, output.data(), output.size());
(void)decode(input.data(), inputSize, output.data(), output.size());
}

// static
void Base64::decode(std::string_view input, size_t size, char* output) {
size_t outputLength = size / 4 * 3;
Base64::decode(input, size, output, outputLength);
void Base64::decode(std::string_view input, size_t inputSize, char* output) {
size_t outputSize;
(void)calculateDecodedSize(input, inputSize, outputSize);
(void)decode(input, inputSize, output, outputSize);
}

// static
uint8_t Base64::base64ReverseLookup(
char character,
const Base64::ReverseIndex& reverseIndex) {
auto lookupValue = reverseIndex[(uint8_t)character];
if (lookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
char p,
const Base64::ReverseIndex& reverseIndex,
Status& status) {
auto curr = reverseIndex[(uint8_t)p];
if (curr >= 0x40) {
status = Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters.");
return 0; // Placeholder value only; the failure is reported through 'status'.
}
return lookupValue;
return curr;
}

// static
Expand Down Expand Up @@ -399,7 +401,7 @@ Status Base64::decodeImpl(
char* output,
size_t outputSize,
const Base64::ReverseIndex& reverseIndex) {
if (!inputSize) {
if (inputSize == 0) {
return Status::OK();
}

Expand All @@ -415,36 +417,44 @@ Status Base64::decodeImpl(
"Base64::decode() - invalid output string: output string is too small.");
}

// Handle full groups of 4 characters
for (; inputSize > 4; inputSize -= 4, input.remove_prefix(4), output += 3) {
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t currentBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
output[0] = (currentBlock >> 16) & 0xff;
output[1] = (currentBlock >> 8) & 0xff;
output[2] = currentBlock & 0xff;
const char* inputPtr = input.data();
char* outputPtr = output;
Status lookupStatus;

// Process full blocks of 4 characters
size_t fullBlockCount = inputSize / 4;
for (size_t i = 0; i < fullBlockCount; ++i) {
uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus);
uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus);
uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus);
uint8_t val3 = base64ReverseLookup(inputPtr[3], reverseIndex, lookupStatus);

uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3;
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
outputPtr[2] = static_cast<char>(currentBlock & 0xFF);

inputPtr += 4;
outputPtr += 3;
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t currentBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
output[0] = (currentBlock >> 16) & 0xff;
if (inputSize > 2) {
currentBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
output[1] = (currentBlock >> 8) & 0xff;
if (inputSize > 3) {
currentBlock |= base64ReverseLookup(input[3], reverseIndex);
output[2] = currentBlock & 0xff;
// Handle the last block (2-3 characters)
size_t remaining = inputSize % 4;
if (remaining > 1) {
uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus);
uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus);
uint32_t currentBlock = (val0 << 18) | (val1 << 12);
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);

if (remaining == 3) {
uint8_t val2 =
base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus);
currentBlock |= (val2 << 6);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
}
}

if (!lookupStatus.ok())
return lookupStatus;
return Status::OK();
}

Expand Down Expand Up @@ -477,15 +487,42 @@ std::string Base64::decodeUrl(std::string_view input) {

// static
void Base64::decodeUrl(std::string_view input, std::string& output) {
size_t out_len = (input.size() + 3) / 4 * 3;
output.resize(out_len, '\0');
Base64::decodeImpl(
// Early exit if input is empty
if (input.empty()) {
output.clear();
return;
}

size_t inputSize = input.size();
size_t outputSize;

// Calculate the size for the decoded output
auto status = calculateDecodedSize(input, inputSize, outputSize);
if (!status.ok()) {
// The size-calculation error is dropped here: the caller only sees an empty output.
output.clear();
return;
}

// Resize the output string to the calculated size
output.resize(outputSize);

// Perform the actual decoding
status = Base64::decodeImpl(
input.data(),
input.size(),
&output[0],
out_len,
inputSize,
output.data(),
outputSize,
kBase64UrlReverseIndexTable);
output.resize(out_len);

if (!status.ok()) {
// The decoding error is dropped here: the caller only sees an empty output.
output.clear();
return;
}

// Output was already resized to outputSize before decoding, so this resize leaves the length unchanged.
output.resize(outputSize);
}

} // namespace facebook::velox::encoding
12 changes: 6 additions & 6 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,19 +121,18 @@ class Base64 {

// Counts the number of trailing padding characters in the encoded input.
static inline size_t numPadding(std::string_view input, size_t inputSize) {
size_t padding = 0;
size_t numPadding{0};
while (inputSize > 0 && input[inputSize - 1] == kPadding) {
padding++;
numPadding++;
inputSize--;
}
return padding;
return numPadding;
}

// Looks up a character in the reverse index and returns its position in the
// Base64 alphabet. A reverse-index value of 0x40 or above marks an invalid
// character and is reported as an error.
static uint8_t base64ReverseLookup(
char character,
const ReverseIndex& reverseIndex);
static uint8_t
base64ReverseLookup(char p, const ReverseIndex& reverseIndex, Status& status);

// Encodes the specified input using the provided charset.
template <class T>
Expand All @@ -158,6 +157,7 @@ class Base64 {

VELOX_FRIEND_TEST(Base64Test, isPadded);
VELOX_FRIEND_TEST(Base64Test, numPadding);
VELOX_FRIEND_TEST(Base64Test, testDecodeImpl);
};

} // namespace facebook::velox::encoding
Loading

0 comments on commit 70adfa6

Please sign in to comment.