Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Aug 19, 2024
1 parent c57e4c9 commit a85d043
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 90 deletions.
86 changes: 48 additions & 38 deletions velox/common/encode/Base64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
#include <folly/io/Cursor.h>
#include <stdint.h>

#include "velox/common/base/Exceptions.h"

namespace facebook::velox::encoding {

// Constants defining the size in bytes of binary and encoded blocks for Base64
Expand Down Expand Up @@ -324,16 +322,6 @@ void Base64::decode(std::string_view input, size_t size, char* output) {
Base64::decode(input, size, output, outputLength);
}

// static
uint8_t Base64::base64ReverseLookup(
char character,
const Base64::ReverseIndex& reverseIndex) {
auto lookupValue = reverseIndex[(uint8_t)character];
if (lookupValue >= 0x40) {
VELOX_USER_FAIL("decode() - invalid input string: invalid characters");
}
return lookupValue;
}

// static
Status Base64::decode(
Expand Down Expand Up @@ -399,7 +387,7 @@ Status Base64::decodeImpl(
char* output,
size_t outputSize,
const Base64::ReverseIndex& reverseIndex) {
if (!inputSize) {
if (inputSize == 0) {
return Status::OK();
}

Expand All @@ -415,33 +403,55 @@ Status Base64::decodeImpl(
"Base64::decode() - invalid output string: output string is too small.");
}

// Handle full groups of 4 characters
for (; inputSize > 4; inputSize -= 4, input.remove_prefix(4), output += 3) {
// Each character of the 4 encodes 6 bits of the original, grab each with
// the appropriate shifts to rebuild the original and then split that back
// into the original 8-bit bytes.
uint32_t currentBlock =
(base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12) |
(base64ReverseLookup(input[2], reverseIndex) << 6) |
base64ReverseLookup(input[3], reverseIndex);
output[0] = (currentBlock >> 16) & 0xff;
output[1] = (currentBlock >> 8) & 0xff;
output[2] = currentBlock & 0xff;
const char* inputPtr = input.data();
char* outputPtr = output;

// Process full blocks of 4 characters
size_t fullBlockCount = inputSize / 4;
for (size_t i = 0; i < fullBlockCount; ++i) {
uint8_t val0 = reverseIndex[static_cast<uint8_t>(inputPtr[0])];
uint8_t val1 = reverseIndex[static_cast<uint8_t>(inputPtr[1])];
uint8_t val2 = reverseIndex[static_cast<uint8_t>(inputPtr[2])];
uint8_t val3 = reverseIndex[static_cast<uint8_t>(inputPtr[3])];

// Check for invalid values
if (val0 == 0xFF || val1 == 0xFF || val2 == 0xFF || val3 == 0xFF) {
return Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters.");
}

uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3;
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
outputPtr[2] = static_cast<char>(currentBlock & 0xFF);

inputPtr += 4;
outputPtr += 3;
}

// Handle the last 2-4 characters. This is similar to the above, but the
// last 2 characters may or may not exist.
DCHECK(inputSize >= 2);
uint32_t currentBlock = (base64ReverseLookup(input[0], reverseIndex) << 18) |
(base64ReverseLookup(input[1], reverseIndex) << 12);
output[0] = (currentBlock >> 16) & 0xff;
if (inputSize > 2) {
currentBlock |= base64ReverseLookup(input[2], reverseIndex) << 6;
output[1] = (currentBlock >> 8) & 0xff;
if (inputSize > 3) {
currentBlock |= base64ReverseLookup(input[3], reverseIndex);
output[2] = currentBlock & 0xff;
// Handle the last block (2-3 characters)
size_t remaining = inputSize % 4;
if (remaining > 1) {
uint8_t val0 = reverseIndex[static_cast<uint8_t>(inputPtr[0])];
uint8_t val1 = reverseIndex[static_cast<uint8_t>(inputPtr[1])];

// Check for invalid values
if (val0 == 0xFF || val1 == 0xFF) {
return Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters.");
}

uint32_t currentBlock = (val0 << 18) | (val1 << 12);
outputPtr[0] = static_cast<char>((currentBlock >> 16) & 0xFF);

if (remaining == 3) {
uint8_t val2 = reverseIndex[static_cast<uint8_t>(inputPtr[2])];
if (val2 == 0xFF) {
return Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters.");
}
currentBlock |= (val2 << 6);
outputPtr[1] = static_cast<char>((currentBlock >> 8) & 0xFF);
}
}

Expand Down
7 changes: 1 addition & 6 deletions velox/common/encode/Base64.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ class Base64 {
return padding;
}

// Performs a reverse lookup in the reverse index to retrieve the original
// index of a character in the base.
static uint8_t base64ReverseLookup(
char character,
const ReverseIndex& reverseIndex);

// Encodes the specified input using the provided charset.
template <class T>
static std::string
Expand All @@ -158,6 +152,7 @@ class Base64 {

VELOX_FRIEND_TEST(Base64Test, isPadded);
VELOX_FRIEND_TEST(Base64Test, numPadding);
VELOX_FRIEND_TEST(Base64Test, testDecodeImpl);
};

} // namespace facebook::velox::encoding
130 changes: 90 additions & 40 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,21 @@

namespace facebook::velox::encoding {

class Base64Test : public ::testing::Test {};
class Base64Test : public ::testing::Test {
protected:
void checkDecodedSize(
const std::string& encodedString,
size_t expectedEncodedSize,
size_t expectedDecodedSize) {
size_t encodedSize = expectedEncodedSize;
size_t decodedSize = 0;
EXPECT_EQ(
Status::OK(),
Base64::calculateDecodedSize(encodedString, encodedSize, decodedSize));
EXPECT_EQ(expectedEncodedSize, encodedSize);
EXPECT_EQ(expectedDecodedSize, decodedSize);
}
};

TEST_F(Base64Test, fromBase64) {
EXPECT_EQ("Hello, World!", Base64::decode("SGVsbG8sIFdvcmxkIQ=="));
Expand All @@ -43,49 +57,23 @@ TEST_F(Base64Test, fromBase64) {
}

TEST_F(Base64Test, calculateDecodedSizeProperSize) {
size_t encoded_size{0};
size_t decoded_size{0};

encoded_size = 20;
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ==", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 18;
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ", encoded_size, decoded_size);
EXPECT_EQ(18, encoded_size);
EXPECT_EQ(13, decoded_size);

encoded_size = 21;
checkDecodedSize("SGVsbG8sIFdvcmxkIQ==", 18, 13);
checkDecodedSize("SGVsbG8sIFdvcmxkIQ", 18, 13);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 31, 23);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 23);
checkDecodedSize("MTIzNDU2Nzg5MA==", 14, 10);
checkDecodedSize("MTIzNDU2Nzg5MA", 14, 10);
}

TEST_F(Base64Test, calculateDecodedSizeImproperSize) {
size_t encodedSize{21};
size_t decodedSize;

EXPECT_EQ(
Status::UserError(
"Base64::decode() - invalid input string: string length is not a multiple of 4."),
Base64::calculateDecodedSize(
"SGVsbG8sIFdvcmxkIQ===", encoded_size, decoded_size));

encoded_size = 32;
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 31;
Base64::calculateDecodedSize(
"QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", encoded_size, decoded_size);
EXPECT_EQ(31, encoded_size);
EXPECT_EQ(23, decoded_size);

encoded_size = 16;
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA==", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);

encoded_size = 14;
Base64::calculateDecodedSize("MTIzNDU2Nzg5MA", encoded_size, decoded_size);
EXPECT_EQ(14, encoded_size);
EXPECT_EQ(10, decoded_size);
"SGVsbG8sIFdvcmxkIQ===", encodedSize, decodedSize));
}

TEST_F(Base64Test, isPadded) {
Expand All @@ -98,4 +86,66 @@ TEST_F(Base64Test, numPadding) {
EXPECT_EQ(1, Base64::numPadding("ABC=", 4));
EXPECT_EQ(2, Base64::numPadding("AB==", 4));
}

TEST_F(Base64Test, testDecodeImpl) {
char output[100];

// Reverse lookup tables for decoding
constexpr const Base64::ReverseIndex reverseTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255,
255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

// Invalid Base64 input: invalid character `$`
std::string_view invalidInput1 = "SGVsbG8gd29ybGQ$";
EXPECT_EQ(
Base64::decodeImpl(
invalidInput1,
invalidInput1.size(),
output,
sizeof(output),
reverseTable),
Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters."));

// Invalid Base64 input: incorrect padding
std::string_view invalidInput2 = "====";
EXPECT_EQ(
Base64::decodeImpl(
invalidInput2,
invalidInput2.size(),
output,
sizeof(output),
reverseTable),
Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters."));

// Invalid Base64 input: incomplete encoding
std::string_view invalidInput4 = "S";
EXPECT_EQ(
Base64::decodeImpl(
invalidInput4,
2,
output,
sizeof(output),
reverseTable),
Status::UserError(
"Base64::decode() - invalid input string: contains invalid characters."));
}

} // namespace facebook::velox::encoding
7 changes: 1 addition & 6 deletions velox/common/encode/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,4 @@ add_executable(velox_common_encode_test Base64Test.cpp)
add_test(velox_common_encode_test velox_common_encode_test)
target_link_libraries(
velox_common_encode_test
PRIVATE
velox_encode
velox_status
velox_exception
GTest::gtest
GTest::gtest_main)
PRIVATE velox_encode velox_status GTest::gtest GTest::gtest_main)

0 comments on commit a85d043

Please sign in to comment.