Skip to content

Commit

Permalink
XXX
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Oct 5, 2024
1 parent cecd6ec commit b619d8b
Show file tree
Hide file tree
Showing 7 changed files with 277 additions and 13 deletions.
186 changes: 186 additions & 0 deletions velox/common/encode/Base32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/encode/Base32.h"

#include <glog/logging.h>

namespace facebook::velox::encoding {

// Constants defining the size in bytes of binary and encoded blocks for Base32
// encoding.
// Size of a binary block in bytes (5 bytes = 40 bits)
constexpr static int kBinaryBlockByteSize = 5;
// Size of an encoded block in bytes (8 bytes = 40 bits)
constexpr static int kEncodedBlockByteSize = 8;

constexpr Base32::Charset kBase32Charset = {
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', '2', '3', '4', '5', '6', '7'};

constexpr Base32::ReverseIndex kBase32ReverseIndexTable = {
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 255, 255, 255, 255,
255, 255, 255, 255, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
25, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
255};

// Verify that for each 32 entries in kBase32Charset, the corresponding entry
// in kBase32ReverseIndexTable is correct.
static_assert(
checkForwardIndex(
sizeof(kBase32Charset) / 2 - 1,
kBase32Charset,
kBase32ReverseIndexTable),
"kBase32Charset has incorrect entries");

// Verify that for every entry in kBase32ReverseIndexTable, the corresponding
// entry in kBase32Charset is correct.
static_assert(
checkReverseIndex(
sizeof(kBase32ReverseIndexTable) - 1,
kBase32Charset,
kBase32ReverseIndexTable),
"kBase32ReverseIndexTable has incorrect entries.");

// static
uint8_t Base32::base32ReverseLookup(
char p,
const Base32::ReverseIndex& reverseIndex,
Status& status) {
return reverseLookup(p, reverseIndex, status, Base32::kCharsetSize);
}

// static
Status Base32::decode(std::string_view input, std::string& output) {
return decodeImpl(
input,
input.size(),
output.data(),
output.size(),
kBase32ReverseIndexTable);
}

// static
Status Base32::decodeImpl(
std::string_view input,
size_t inputSize,
char* output,
size_t outputSize,
const Base32::ReverseIndex& reverseIndex) {
// Check if input is empty
if (input.empty()) {
return Status::OK();
}

size_t decodedSize;
// Calculate decoded size and check for status
auto status = calculateDecodedSize(
input,
inputSize,
decodedSize,
kBinaryBlockByteSize,
kEncodedBlockByteSize);
if (!status.ok()) {
return status;
}

if (outputSize < decodedSize) {
return Status::UserError("Base32::decode() - output buffer too small.");
}

Status lookupStatus;
// Handle full groups of 8 characters.
while (inputSize >= 8) {
// Each character of the 8 bytes encodes 5 bits of the original, grab each
// with the appropriate shifts to rebuild the original and then split that
// back into the original 8-bit bytes.
uint64_t last =
(uint64_t(base32ReverseLookup(input[0], reverseIndex, lookupStatus))
<< 35) |
(uint64_t(base32ReverseLookup(input[1], reverseIndex, lookupStatus))
<< 30) |
(base32ReverseLookup(input[2], reverseIndex, lookupStatus) << 25) |
(base32ReverseLookup(input[3], reverseIndex, lookupStatus) << 20) |
(base32ReverseLookup(input[4], reverseIndex, lookupStatus) << 15) |
(base32ReverseLookup(input[5], reverseIndex, lookupStatus) << 10) |
(base32ReverseLookup(input[6], reverseIndex, lookupStatus) << 5) |
base32ReverseLookup(input[7], reverseIndex, lookupStatus);

output[0] = (last >> 32) & 0xff;
output[1] = (last >> 24) & 0xff;
output[2] = (last >> 16) & 0xff;
output[3] = (last >> 8) & 0xff;
output[4] = last & 0xff;

// Move the input string_view forward
input.remove_prefix(8);
output += 5;
inputSize -= 8;
}

// Handle the last 2, 4, 5, 7, or 8 characters.
if (inputSize >= 2) {
uint64_t last =
(uint64_t(base32ReverseLookup(input[0], reverseIndex, lookupStatus))
<< 35) |
(uint64_t(base32ReverseLookup(input[1], reverseIndex, lookupStatus))
<< 30);
output[0] = (last >> 32) & 0xff;

if (inputSize > 2) {
last |= base32ReverseLookup(input[2], reverseIndex, lookupStatus) << 25;
last |= base32ReverseLookup(input[3], reverseIndex, lookupStatus) << 20;
output[1] = (last >> 24) & 0xff;

if (inputSize > 4) {
last |= base32ReverseLookup(input[4], reverseIndex, lookupStatus) << 15;
output[2] = (last >> 16) & 0xff;

if (inputSize > 5) {
last |= base32ReverseLookup(input[5], reverseIndex, lookupStatus)
<< 10;
last |= base32ReverseLookup(input[6], reverseIndex, lookupStatus)
<< 5;
output[3] = (last >> 8) & 0xff;

if (inputSize > 7) {
last |= base32ReverseLookup(input[7], reverseIndex, lookupStatus);
output[4] = last & 0xff;
}
}
}
}
}

return lookupStatus.ok() ? Status::OK() : lookupStatus;
}

} // namespace facebook::velox::encoding

64 changes: 64 additions & 0 deletions velox/common/encode/Base32.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <exception>
#include <map>
#include <string>

#include "velox/common/base/GTestMacros.h"
#include "velox/common/base/Status.h"
#include "velox/common/encode/EncoderUtils.h"

namespace facebook::velox::encoding {

class Base32 {
public:
static const size_t kCharsetSize = 32;
static const size_t kReverseIndexSize = 256;

/// Character set used for encoding purposes.
/// Contains specific characters that form the encoding scheme.
using Charset = std::array<char, kCharsetSize>;

/// Reverse lookup table for decoding purposes.
/// Maps each possible encoded character to its corresponding numeric value
/// within the encoding base.
using ReverseIndex = std::array<uint8_t, kReverseIndexSize>;

/// Decodes the specified number of characters from the 'input' and writes the
/// result to the 'output'.
static Status decode(std::string_view input, std::string& output);

private:
// Performs a reverse lookup in the reverse index to retrieve the original
// index of a character in the base.
static uint8_t base32ReverseLookup(
char p,
const Base32::ReverseIndex& reverseIndex,
Status& status);

// Decodes the specified input using the provided reverse lookup table.
static Status decodeImpl(
std::string_view input,
size_t inputSize,
char* output,
size_t outputSize,
const Base32::ReverseIndex& reverseIndex);
};

} // namespace facebook::velox::encoding

2 changes: 1 addition & 1 deletion velox/common/encode/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ if(${VELOX_BUILD_TESTING})
add_subdirectory(tests)
endif()

velox_add_library(velox_encode Base64.cpp)
velox_add_library(velox_encode Base32.cpp Base64.cpp)
velox_link_libraries(velox_encode PUBLIC Folly::folly)
8 changes: 2 additions & 6 deletions velox/common/encode/EncoderUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ static Status calculateDecodedSize(
// If padded, ensure that the string length is a multiple of the encoded
// block size
if (inputSize % encodedBlockByteSize != 0) {
return Status::UserError(
"decode() - invalid input string: "
"string length is not a multiple of 4.");
return Status::UserError("decode() - invalid input string length.");
}

decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize;
Expand All @@ -127,9 +125,7 @@ static Status calculateDecodedSize(
// Adjust the needed size for extra bytes, if present
if (extraBytes) {
if (extraBytes == 1) {
return Status::UserError(
"Base64::decode() - invalid input string: "
"string length cannot be 1 more than a multiple of 4.");
return Status::UserError("decode() - invalid input string length.");
}
decodedSize += (extraBytes * binaryBlockByteSize) / encodedBlockByteSize;
}
Expand Down
3 changes: 1 addition & 2 deletions velox/common/encode/tests/Base64Test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ TEST_F(Base64Test, calculateDecodedSize) {
21,
0,
0,
Status::UserError(
"decode() - invalid input string: string length is not a multiple of 4."));
Status::UserError("decode() - invalid input string length."));
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 31, 23);
checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 31, 23);
checkDecodedSize("MTIzNDU2Nzg5MA==", 16, 14, 10);
Expand Down
21 changes: 21 additions & 0 deletions velox/functions/prestosql/BinaryFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "folly/ssl/OpenSSLHash.h"
#include "velox/common/base/BitUtil.h"
#include "velox/common/encode/Base32.h"
#include "velox/common/encode/Base64.h"
#include "velox/external/md5/md5.h"
#include "velox/functions/Udf.h"
Expand Down Expand Up @@ -347,6 +348,26 @@ struct ToBase64UrlFunction {
}
};

template <typename TExec>
struct FromBase32Function {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

// T can be either arg_type<Varchar> or arg_type<Varbinary>. These are the
// same, but hard-coding one of them might be confusing.
FOLLY_ALWAYS_INLINE Status
call(out_type<Varbinary>& result, const arg_type<Varchar>& input) {
std::string_view inputView(input.data(), input.size());
std::string output;
auto status = encoding::Base32::decode(inputView, output);
if (!status.ok()) {
return status;
}
result.resize(output.size());
std::memcpy(result.data(), output.data(), output.size());
return Status::OK();
}
};

template <typename T>
struct FromBigEndian32 {
VELOX_DEFINE_FUNCTION_TYPES(T);
Expand Down
6 changes: 2 additions & 4 deletions velox/functions/prestosql/tests/BinaryFunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,9 @@ TEST_F(BinaryFunctionsTest, fromBase64) {
fromBase64("SGVsbG8gV29ybGQgZnJvbSBWZWxveCE="));

VELOX_ASSERT_USER_THROW(
fromBase64("YQ="),
"Base64::decode() - invalid input string: string length is not a multiple of 4.");
fromBase64("YQ="), "decode() - invalid input string length.");
VELOX_ASSERT_USER_THROW(
fromBase64("YQ==="),
"Base64::decode() - invalid input string: string length is not a multiple of 4.");
fromBase64("YQ==="), "decode() - invalid input string length.");

// Check encoded strings without padding
EXPECT_EQ("a", fromBase64("YQ"));
Expand Down

0 comments on commit b619d8b

Please sign in to comment.