Skip to content

Commit

Permalink
from_base64
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Oct 5, 2024
1 parent b619d8b commit baf6c1b
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 85 deletions.
183 changes: 108 additions & 75 deletions velox/common/encode/Base32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,36 +71,32 @@ static_assert(

// static
uint8_t Base32::base32ReverseLookup(
char p,
const Base32::ReverseIndex& reverseIndex,
char encodedChar,
const ReverseIndex& reverseIndex,
Status& status) {
return reverseLookup(p, reverseIndex, status, Base32::kCharsetSize);
return reverseLookup(encodedChar, reverseIndex, status, kCharsetSize);
}

// static
Status Base32::decode(std::string_view input, std::string& output) {
return decodeImpl(
input,
input.size(),
output.data(),
output.size(),
kBase32ReverseIndexTable);
return decodeImpl(input, output, kBase32ReverseIndexTable);
}

// static
Status Base32::decodeImpl(
std::string_view input,
size_t inputSize,
char* output,
size_t outputSize,
const Base32::ReverseIndex& reverseIndex) {
// Check if input is empty
if (input.empty()) {
std::string& output,
const ReverseIndex& reverseIndex) {
size_t inputSize = input.size();

// If input is empty, clear output and return OK status.
if (inputSize == 0) {
output.clear();
return Status::OK();
}

// Calculate the decoded size based on the input size.
size_t decodedSize;
// Calculate decoded size and check for status
auto status = calculateDecodedSize(
input,
inputSize,
Expand All @@ -111,76 +107,113 @@ Status Base32::decodeImpl(
return status;
}

if (outputSize < decodedSize) {
return Status::UserError("Base32::decode() - output buffer too small.");
}
// Resize the output to accommodate the decoded data.
output.resize(decodedSize);

const char* inputPtr = input.data();
char* outputPtr = output.data();
Status lookupStatus;
// Handle full groups of 8 characters.
while (inputSize >= 8) {
// Each character of the 8 bytes encodes 5 bits of the original, grab each
// with the appropriate shifts to rebuild the original and then split that
// back into the original 8-bit bytes.
uint64_t last =
(uint64_t(base32ReverseLookup(input[0], reverseIndex, lookupStatus))
<< 35) |
(uint64_t(base32ReverseLookup(input[1], reverseIndex, lookupStatus))
<< 30) |
(base32ReverseLookup(input[2], reverseIndex, lookupStatus) << 25) |
(base32ReverseLookup(input[3], reverseIndex, lookupStatus) << 20) |
(base32ReverseLookup(input[4], reverseIndex, lookupStatus) << 15) |
(base32ReverseLookup(input[5], reverseIndex, lookupStatus) << 10) |
(base32ReverseLookup(input[6], reverseIndex, lookupStatus) << 5) |
base32ReverseLookup(input[7], reverseIndex, lookupStatus);

output[0] = (last >> 32) & 0xff;
output[1] = (last >> 24) & 0xff;
output[2] = (last >> 16) & 0xff;
output[3] = (last >> 8) & 0xff;
output[4] = last & 0xff;

// Move the input string_view forward
input.remove_prefix(8);
output += 5;
inputSize -= 8;

// Process full blocks of 8 characters
size_t fullBlockCount = inputSize / 8;
for (size_t i = 0; i < fullBlockCount; ++i) {
uint64_t inputBlock = 0;

// Decode 8 characters into a 40-bit block
for (int shift = 35, j = 0; j < 8; ++j, shift -= 5) {
uint64_t value =
base32ReverseLookup(inputPtr[j], reverseIndex, lookupStatus);
if (!lookupStatus.ok()) {
return lookupStatus;
}
inputBlock |= (value << shift);
}

// Write the decoded block to the output
outputPtr[0] = static_cast<char>((inputBlock >> 32) & 0xFF);
outputPtr[1] = static_cast<char>((inputBlock >> 24) & 0xFF);
outputPtr[2] = static_cast<char>((inputBlock >> 16) & 0xFF);
outputPtr[3] = static_cast<char>((inputBlock >> 8) & 0xFF);
outputPtr[4] = static_cast<char>(inputBlock & 0xFF);

inputPtr += 8;
outputPtr += 5;
}

// Handle the last 2, 4, 5, 7, or 8 characters.
if (inputSize >= 2) {
uint64_t last =
(uint64_t(base32ReverseLookup(input[0], reverseIndex, lookupStatus))
<< 35) |
(uint64_t(base32ReverseLookup(input[1], reverseIndex, lookupStatus))
// Handle remaining characters (2, 4, 5, 7)
size_t remaining = inputSize % 8;
if (remaining >= 2) {
uint64_t inputBlock = 0;

// Decode the first two characters
inputBlock |=
(static_cast<uint64_t>(
base32ReverseLookup(inputPtr[0], reverseIndex, lookupStatus))
<< 35);
inputBlock |=
(static_cast<uint64_t>(
base32ReverseLookup(inputPtr[1], reverseIndex, lookupStatus))
<< 30);
output[0] = (last >> 32) & 0xff;

if (inputSize > 2) {
last |= base32ReverseLookup(input[2], reverseIndex, lookupStatus) << 25;
last |= base32ReverseLookup(input[3], reverseIndex, lookupStatus) << 20;
output[1] = (last >> 24) & 0xff;

if (inputSize > 4) {
last |= base32ReverseLookup(input[4], reverseIndex, lookupStatus) << 15;
output[2] = (last >> 16) & 0xff;

if (inputSize > 5) {
last |= base32ReverseLookup(input[5], reverseIndex, lookupStatus)
<< 10;
last |= base32ReverseLookup(input[6], reverseIndex, lookupStatus)
<< 5;
output[3] = (last >> 8) & 0xff;

if (inputSize > 7) {
last |= base32ReverseLookup(input[7], reverseIndex, lookupStatus);
output[4] = last & 0xff;

if (!lookupStatus.ok()) {
return lookupStatus;
}
outputPtr[0] = static_cast<char>((inputBlock >> 32) & 0xFF);

if (remaining > 2) {
// Decode the next two characters
inputBlock |=
(base32ReverseLookup(inputPtr[2], reverseIndex, lookupStatus) << 25);
inputBlock |=
(base32ReverseLookup(inputPtr[3], reverseIndex, lookupStatus) << 20);

if (!lookupStatus.ok()) {
return lookupStatus;
}
outputPtr[1] = static_cast<char>((inputBlock >> 24) & 0xFF);

if (remaining > 4) {
// Decode the next character
inputBlock |=
(base32ReverseLookup(inputPtr[4], reverseIndex, lookupStatus)
<< 15);

if (!lookupStatus.ok()) {
return lookupStatus;
}
outputPtr[2] = static_cast<char>((inputBlock >> 16) & 0xFF);

if (remaining > 5) {
// Decode the next two characters
inputBlock |=
(base32ReverseLookup(inputPtr[5], reverseIndex, lookupStatus)
<< 10);
inputBlock |=
(base32ReverseLookup(inputPtr[6], reverseIndex, lookupStatus)
<< 5);

if (!lookupStatus.ok()) {
return lookupStatus;
}
outputPtr[3] = static_cast<char>((inputBlock >> 8) & 0xFF);

if (remaining > 7) {
// Decode the last character
inputBlock |=
base32ReverseLookup(inputPtr[7], reverseIndex, lookupStatus);

if (!lookupStatus.ok()) {
return lookupStatus;
}
outputPtr[4] = static_cast<char>(inputBlock & 0xFF);
}
}
}
}
}

return lookupStatus.ok() ? Status::OK() : lookupStatus;
// Return status
return Status::OK();
}

} // namespace facebook::velox::encoding

14 changes: 4 additions & 10 deletions velox/common/encode/Base32.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
*/
#pragma once

#include <exception>
#include <map>
#include <string>

#include "velox/common/base/GTestMacros.h"
#include "velox/common/base/Status.h"
#include "velox/common/encode/EncoderUtils.h"

Expand Down Expand Up @@ -47,18 +44,15 @@ class Base32 {
// Performs a reverse lookup in the reverse index to retrieve the original
// index of a character in the base.
static uint8_t base32ReverseLookup(
char p,
const Base32::ReverseIndex& reverseIndex,
char encodedChar,
const ReverseIndex& reverseIndex,
Status& status);

// Decodes the specified input using the provided reverse lookup table.
static Status decodeImpl(
std::string_view input,
size_t inputSize,
char* output,
size_t outputSize,
const Base32::ReverseIndex& reverseIndex);
std::string& output,
const ReverseIndex& reverseIndex);
};

} // namespace facebook::velox::encoding

23 changes: 23 additions & 0 deletions velox/docs/functions/presto/binary.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ Binary Functions

Decodes ``string`` data from the base64 encoded representation using the `URL safe alphabet <https://www.rfc-editor.org/rfc/rfc4648#section-5>`_ into a varbinary.

.. function:: from_base64(string) -> varbinary

Decodes a Base64-encoded ``string`` back into its original binary form.
This function can handle both padded and non-padded Base64 encoded strings.
A partially padded Base64 string, i.e. one whose padding does not bring its length to a multiple of 4, is rejected with a "UserError" status.

Examples
--------
Query with padded Base64 string:
::
SELECT from_base64('SGVsbG8gV29ybGQ='); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100]

Query with non-padded Base64 string:
::
SELECT from_base64('SGVsbG8gV29ybGQ'); -- [72, 101, 108, 108, 111, 32, 87, 111, 114, 108, 100]

Query with partially padded Base64 string:
::
SELECT from_base64('SGVsbG8gV29ybGQgZm9yIHZlbG94IQ='); -- Error: Base64::decode() - invalid input string: length is not a multiple of 4.

In the examples above, the fully padded string 'SGVsbG8gV29ybGQ=' and the non-padded string 'SGVsbG8gV29ybGQ' both decode to the binary representation of the text 'Hello World'.
The partially padded string 'SGVsbG8gV29ybGQgZm9yIHZlbG94IQ=' is rejected with a "UserError" status because its length is not a multiple of 4.

.. function:: from_big_endian_32(varbinary) -> integer

Decodes ``integer`` value from a 32-bit 2’s complement big endian ``binary``.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ void registerSimpleFunctions(const std::string& prefix) {
registerFunction<FromBase64UrlFunction, Varbinary, Varchar>(
{prefix + "from_base64url"});

registerFunction<FromBase32Function, Varbinary, Varchar>(
{prefix + "from_base32"});
registerFunction<FromBase32Function, Varbinary, Varbinary>(
{prefix + "from_base32"});

registerFunction<FromBigEndian32, int32_t, Varbinary>(
{prefix + "from_big_endian_32"});
registerFunction<ToBigEndian32, Varbinary, int32_t>(
Expand Down
51 changes: 51 additions & 0 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,57 @@ TEST_F(BinaryFunctionsTest, fromBase64Url) {
EXPECT_THROW(fromBase64Url("YQ=/"), VeloxUserError);
}

TEST_F(BinaryFunctionsTest, fromBase32) {
  // Evaluates from_base32 on `value` twice, once with the argument typed as
  // VARCHAR and once as VARBINARY, checks that the two signatures agree, and
  // returns the VARCHAR result.
  const auto fromBase32 = [&](std::optional<std::string> value) {
    auto result =
        evaluateOnce<std::string>("from_base32(c0)", VARCHAR(), value);
    auto otherResult =
        evaluateOnce<std::string>("from_base32(c0)", VARBINARY(), value);

    // Both signatures must agree on nullness before the values are compared.
    VELOX_CHECK_EQ(result.has_value(), otherResult.has_value());
    if (result.has_value()) {
      VELOX_CHECK_EQ(result.value(), otherResult.value());
    }
    return result;
  };

  // Null and empty inputs.
  EXPECT_EQ(std::nullopt, fromBase32(std::nullopt));
  EXPECT_EQ("", fromBase32(""));

  // Inputs carrying '=' padding.
  EXPECT_EQ("a", fromBase32("ME======"));
  EXPECT_EQ("ab", fromBase32("MFRA===="));
  EXPECT_EQ("abc", fromBase32("MFRGG==="));
  EXPECT_EQ("db2", fromBase32("MRRDE==="));
  EXPECT_EQ("abcd", fromBase32("MFRGGZA="));
  EXPECT_EQ("hello world", fromBase32("NBSWY3DPEB3W64TMMQ======"));
  EXPECT_EQ(
      "Hello World from Velox!",
      fromBase32("JBSWY3DPEBLW64TMMQQGM4TPNUQFMZLMN54CC==="));

  // The same kind of inputs with the padding left off.
  EXPECT_EQ("a", fromBase32("ME"));
  EXPECT_EQ("ab", fromBase32("MFRA"));
  EXPECT_EQ("abc", fromBase32("MFRGG"));
  EXPECT_EQ("db2", fromBase32("MRRDE"));
  EXPECT_EQ("abcd", fromBase32("MFRGGZA"));
  EXPECT_EQ("1234", fromBase32("GEZDGNA"));
  EXPECT_EQ("abcde", fromBase32("MFRGGZDF"));
  EXPECT_EQ("abcdef", fromBase32("MFRGGZDFMY"));

  // Malformed inputs: each call is expected to raise a user error with the
  // given message (an unusable length, then characters that are not part of
  // the Base32 alphabet).
  VELOX_ASSERT_USER_THROW(
      fromBase32("1="), "decode() - invalid input string length.");
  VELOX_ASSERT_USER_THROW(
      fromBase32("M1======"),
      "invalid input string: contains invalid characters.");
  VELOX_ASSERT_USER_THROW(
      fromBase32("J$======"),
      "invalid input string: contains invalid characters.");
}

TEST_F(BinaryFunctionsTest, fromBigEndian32) {
const auto fromBigEndian32 = [&](const std::optional<std::string>& arg) {
return evaluateOnce<int32_t>("from_big_endian_32(c0)", VARBINARY(), arg);
Expand Down

0 comments on commit baf6c1b

Please sign in to comment.