Skip to content

Commit

Permalink
Refactor Encode API
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Oct 4, 2024
1 parent 3852bd4 commit 818a4ae
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 49 deletions.
73 changes: 37 additions & 36 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,8 @@ std::string Base64::encodeImpl(
const T& input,
const Charset& charset,
bool includePadding) {
const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)};
std::string encodedResult;
encodedResult.resize(encodedSize);
(void)encodeImpl(input, charset, includePadding, encodedResult.data());
(void)encodeImpl(input, charset, includePadding, encodedResult);
return encodedResult;
}

Expand All @@ -184,77 +182,80 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) {
}

// static
Status Base64::encode(const char* input, size_t inputSize, char* output) {
return encodeImpl(
folly::StringPiece(input, inputSize), kBase64Charset, true, output);
Status Base64::encode(std::string_view input, std::string& output) {
return encodeImpl(input, kBase64Charset, true, output);
}

// static
Status
Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) {
Base64::encodeUrl(std::string_view input, std::string& output) {
return encodeImpl(
folly::StringPiece(input, inputSize),
input,
kBase64UrlCharset,
true,
outputBuffer);
output);
}

// static
template <class T>
Status Base64::encodeImpl(
const T& input,
const Base64::Charset& charset,
const Charset& charset,
bool includePadding,
char* outputBuffer) {
std::string& output) {
auto inputSize = input.size();
if (inputSize == 0) {
output.clear();
return Status::OK();
}

auto outputPointer = outputBuffer;
// Calculate the output size and resize the string beforehand
size_t outputSize = calculateEncodedSize(inputSize, includePadding);
output.resize(outputSize); // Resize the output string to the required size

// Use a pointer to write into the pre-allocated buffer
auto outputPointer = output.data();
auto inputIterator = input.begin();

// For each group of 3 bytes (24 bits) in the input, split that into
// 4 groups of 6 bits and encode that using the supplied charset lookup
// Encode input in chunks of 3 bytes
for (; inputSize > 2; inputSize -= 3) {
uint32_t inputBlock = static_cast<uint8_t>(*inputIterator++) << 16;
inputBlock |= static_cast<uint8_t>(*inputIterator++) << 8;
inputBlock |= static_cast<uint8_t>(*inputIterator++);
uint32_t inputBlock = uint8_t(*inputIterator++) << 16;
inputBlock |= uint8_t(*inputIterator++) << 8;
inputBlock |= uint8_t(*inputIterator++);

*outputPointer++ = charset[(inputBlock >> 18) & 0x3f];
*outputPointer++ = charset[(inputBlock >> 12) & 0x3f];
*outputPointer++ = charset[(inputBlock >> 6) & 0x3f];
*outputPointer++ = charset[inputBlock & 0x3f];
}

// Handle remaining bytes (1 or 2 bytes)
if (inputSize > 0) {
// We have either 1 or 2 input bytes left. Encode this similar to the
// above (assuming 0 for all other bytes). Optionally append the '='
// character if it is requested.
uint32_t inputBlock = static_cast<uint8_t>(*inputIterator++) << 16;
uint32_t inputBlock = uint8_t(*inputIterator++) << 16;
*outputPointer++ = charset[(inputBlock >> 18) & 0x3f];
if (inputSize > 1) {
inputBlock |= static_cast<uint8_t>(*inputIterator) << 8;
inputBlock |= uint8_t(*inputIterator) << 8;
*outputPointer++ = charset[(inputBlock >> 12) & 0x3f];
*outputPointer++ = charset[(inputBlock >> 6) & 0x3f];
if (includePadding) {
*outputPointer = kPadding;
*outputPointer++ = kPadding;
}
} else {
*outputPointer++ = charset[(inputBlock >> 12) & 0x3f];
if (includePadding) {
*outputPointer++ = kPadding;
*outputPointer = kPadding;
*outputPointer++ = kPadding;
}
}
}

return Status::OK();
}


// static
std::string Base64::encode(folly::StringPiece text) {
return encodeImpl(text, kBase64Charset, true);
std::string Base64::encode(folly::StringPiece input) {
return encodeImpl(input, kBase64Charset, true);
}

// static
Expand Down Expand Up @@ -425,7 +426,7 @@ Status Base64::decodeImpl(

// Set up input and output pointers
const char* inputPtr = input.data();
char* outputPtr = output.data();
char* outputPointer = output.data();
Status lookupStatus;

// Process full blocks of 4 characters
Expand All @@ -440,28 +441,28 @@ Status Base64::decodeImpl(
return lookupStatus;
}

uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3;
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
outputPtr[2] = static_cast<char>(currentBlock & 0xFF);
uint32_t inputBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3;
outputPointer[0] = static_cast<char>((inputBlock >> 16) & 0xFF);
outputPointer[1] = static_cast<char>((inputBlock >> 8) & 0xFF);
outputPointer[2] = static_cast<char>(inputBlock & 0xFF);

inputPtr += 4;
outputPtr += 3;
outputPointer += 3;
}

// Handle remaining characters (2 or 3 characters at the end)
size_t remaining = inputSize % 4;
if (remaining > 1) {
uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus);
uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus);
uint32_t currentBlock = (val0 << 18) | (val1 << 12);
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);
uint32_t inputBlock = (val0 << 18) | (val1 << 12);
outputPointer[0] = static_cast<char>((inputBlock >> 16) & 0xFF);

if (remaining == 3) {
uint8_t val2 =
base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus);
currentBlock |= (val2 << 6);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
inputBlock |= (val2 << 6);
outputPointer[1] = static_cast<char>((inputBlock >> 8) & 0xFF);
}
}

Expand Down
15 changes: 7 additions & 8 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ class Base64 {
// Encoding Functions
/// Encodes the input data using Base64 encoding.
static std::string encode(const char* input, size_t inputSize);
static std::string encode(folly::StringPiece text);
static std::string encode(folly::StringPiece input);
static std::string encode(const folly::IOBuf* inputBuffer);
static Status encode(const char* input, size_t inputSize, char* outputBuffer);
static Status encode(std::string_view input, std::string& outputBuffer);

/// Encodes the input data using Base64 URL encoding.
static std::string encodeUrl(const char* input, size_t inputSize);
static std::string encodeUrl(folly::StringPiece text);
static std::string encodeUrl(const folly::IOBuf* inputBuffer);
static Status
encodeUrl(const char* input, size_t inputSize, char* outputBuffer);
encodeUrl(std::string_view input, std::string& output);

// Decoding Functions
/// Decodes the input Base64 encoded string.
Expand All @@ -69,10 +69,6 @@ class Base64 {
std::string& output);
static Status decodeUrl(std::string_view input, std::string& output);

// Helper Functions
/// Calculates the encoded size based on input size.
static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true);

private:
// Checks if the input Base64 string is padded.
static inline bool isPadded(std::string_view input) {
Expand Down Expand Up @@ -107,7 +103,7 @@ class Base64 {
const T& input,
const Charset& charset,
bool includePadding,
char* outputBuffer);
std::string& output);

static Status decodeImpl(
std::string_view input,
Expand All @@ -121,6 +117,9 @@ class Base64 {
size_t& inputSize,
size_t& decodedSize);

// Calculates the encoded size based on input size.
static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true);

VELOX_FRIEND_TEST(Base64Test, isPadded);
VELOX_FRIEND_TEST(Base64Test, numPadding);
};
Expand Down
23 changes: 18 additions & 5 deletions velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,15 @@ struct ToBase64Function {

FOLLY_ALWAYS_INLINE Status
call(out_type<Varchar>& result, const arg_type<Varbinary>& input) {
result.resize(encoding::Base64::calculateEncodedSize(input.size()));
return encoding::Base64::encode(input.data(), input.size(), result.data());
std::string_view inputView(input.data(), input.size());
std::string output;
auto status = encoding::Base64::encode(inputView, output);
if (!status.ok()) {
return status;
}
result.resize(output.size());
std::memcpy(result.data(), output.data(), output.size());
return Status::OK();
}
};

Expand Down Expand Up @@ -328,9 +335,15 @@ struct ToBase64UrlFunction {

FOLLY_ALWAYS_INLINE Status
call(out_type<Varchar>& result, const arg_type<Varbinary>& input) {
result.resize(encoding::Base64::calculateEncodedSize(input.size()));
return encoding::Base64::encodeUrl(
input.data(), input.size(), result.data());
std::string_view inputView(input.data(), input.size());
std::string output;
auto status = encoding::Base64::encodeUrl(inputView, output);
if (!status.ok()) {
return status;
}
result.resize(output.size());
std::memcpy(result.data(), output.data(), output.size());
return Status::OK();
}
};

Expand Down

0 comments on commit 818a4ae

Please sign in to comment.