From cecd6ec47d3088e641c20725a5405aae06e83021 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Wed, 7 Aug 2024 18:44:55 +0530 Subject: [PATCH] Introduce utility class and refactor --- velox/common/encode/Base64.cpp | 80 +-------- velox/common/encode/Base64.h | 11 +- velox/common/encode/EncoderUtils.h | 167 ++++++++++++++++++ velox/common/encode/tests/Base64Test.cpp | 4 +- velox/common/encode/tests/CMakeLists.txt | 2 +- .../common/encode/tests/EncoderUtilsTests.cpp | 35 ++++ 6 files changed, 215 insertions(+), 84 deletions(-) create mode 100644 velox/common/encode/EncoderUtils.h create mode 100644 velox/common/encode/tests/EncoderUtilsTests.cpp diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index f5a0a8c54d37..7bf72480118b 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -166,21 +166,6 @@ std::string Base64::encodeImpl( return encodedResult; } -// static -size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) { - if (inputSize == 0) { - return 0; - } - - // Calculate the output size assuming that we are including padding. - size_t encodedSize = ((inputSize + 2) / 3) * 4; - if (!includePadding) { - // If the padding was not requested, subtract the padding bytes. - encodedSize -= (3 - (inputSize % 3)) % 3; - } - return encodedSize; -} - // static Status Base64::encode(std::string_view input, std::string& output) { return encodeImpl(input, kBase64Charset, true, output); @@ -205,7 +190,8 @@ Status Base64::encodeImpl( } // Calculate the output size and resize the string beforehand - size_t outputSize = calculateEncodedSize(inputSize, includePadding); + size_t outputSize = calculateEncodedSize( + inputSize, includePadding, kBinaryBlockByteSize, kEncodedBlockByteSize); output.resize(outputSize); // Resize the output string to the required size // Use a pointer to write into the pre-allocated buffer @@ -337,12 +323,7 @@ uint8_t Base64::base64ReverseLookup( char encodedChar, const ReverseIndex& reverseIndex, Status& status) { - auto reverseLookupValue = reverseIndex[static_cast(encodedChar)]; - if (reverseLookupValue >= 0x40) { - status = Status::UserError( - "decode() - invalid input string: invalid characters"); - } - return reverseLookupValue; + return reverseLookup(encodedChar, reverseIndex, status, kCharsetSize); } // static @@ -350,54 +331,6 @@ Status Base64::decode(std::string_view input, std::string& output) { return decodeImpl(input, output, kBase64ReverseIndexTable); } -// static -Status Base64::calculateDecodedSize( - std::string_view input, - size_t& inputSize, - size_t& decodedSize) { - if (inputSize == 0) { - decodedSize = 0; - return Status::OK(); - } - - // Check if the input string is padded - if (isPadded(input)) { - // If padded, ensure that the string length is a multiple of the encoded - // block size - if (inputSize % kEncodedBlockByteSize != 0) { - return Status::UserError( - "Base64::decode() - invalid input string: " - "string length is not a multiple of 4."); - } - - decodedSize = (inputSize * kBinaryBlockByteSize) / kEncodedBlockByteSize; - auto paddingCount = numPadding(input); - inputSize -= paddingCount; - - // Adjust the needed size by deducting the bytes corresponding to the - // padding from the calculated size. - decodedSize -= - ((paddingCount * kBinaryBlockByteSize) + (kEncodedBlockByteSize - 1)) / - kEncodedBlockByteSize; - return Status::OK(); - } - // If not padded, calculate extra bytes, if any - auto extraBytes = inputSize % kEncodedBlockByteSize; - decodedSize = (inputSize / kEncodedBlockByteSize) * kBinaryBlockByteSize; - - // Adjust the needed size for extra bytes, if present - if (extraBytes) { - if (extraBytes == 1) { - return Status::UserError( - "Base64::decode() - invalid input string: " - "string length cannot be 1 more than a multiple of 4."); - } - decodedSize += (extraBytes * kBinaryBlockByteSize) / kEncodedBlockByteSize; - } - - return Status::OK(); -} - // static Status Base64::decodeImpl( std::string_view input, @@ -411,7 +344,12 @@ Status Base64::decodeImpl( // Calculate the decoded size based on the input size size_t decodedSize; - auto status = calculateDecodedSize(input, inputSize, decodedSize); + auto status = calculateDecodedSize( + input, + inputSize, + decodedSize, + kBinaryBlockByteSize, + kEncodedBlockByteSize); if (!status.ok()) { return status; } diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index c45e745c8e8b..a9c515ee6078 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -22,6 +22,7 @@ #include #include "velox/common/base/GTestMacros.h" #include "velox/common/base/Status.h" +#include "velox/common/encode/EncoderUtils.h" namespace facebook::velox::encoding { @@ -109,16 +110,6 @@ class Base64 { std::string& output, const ReverseIndex& reverseIndex); - // Returns the actual size of the decoded data. Will also remove the padding - // length from the 'inputSize'. - static Status calculateDecodedSize( - std::string_view input, - size_t& inputSize, - size_t& decodedSize); - - // Calculates the encoded size based on input size. - static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); - VELOX_FRIEND_TEST(Base64Test, isPadded); VELOX_FRIEND_TEST(Base64Test, numPadding); VELOX_FRIEND_TEST(Base64Test, calculateDecodedSize); diff --git a/velox/common/encode/EncoderUtils.h b/velox/common/encode/EncoderUtils.h new file mode 100644 index 000000000000..7c5a8a5b09e5 --- /dev/null +++ b/velox/common/encode/EncoderUtils.h @@ -0,0 +1,167 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/common/base/Status.h" + +namespace facebook::velox::encoding { + +/// Padding character used in encoding. +const static char kPadding = '='; + +// Checks if the input Base64 string is padded. +static inline bool isPadded(std::string_view input) { + size_t inputSize{input.size()}; + return (inputSize > 0 && input[inputSize - 1] == kPadding); +} + +// Counts the number of padding characters in encoded input. +static inline size_t numPadding(std::string_view input) { + size_t numPadding{0}; + size_t inputSize{input.size()}; + while (inputSize > 0 && input[inputSize - 1] == kPadding) { + numPadding++; + inputSize--; + } + return numPadding; +} + +// Validate the character in charset with ReverseIndex table +template +constexpr bool checkForwardIndex( + uint8_t index, + const Charset& charset, + const ReverseIndex& reverseIndex) { + return (reverseIndex[static_cast(charset[index])] == index) && + (index > 0 ? checkForwardIndex(index - 1, charset, reverseIndex) : true); +} + +// Searches for a character within a charset up to a certain index. +template +constexpr bool findCharacterInCharset( + const Charset& charset, + uint8_t index, + const char targetChar) { + return index < charset.size() && + ((charset[index] == targetChar) || + findCharacterInCharset(charset, index + 1, targetChar)); +} + +// Checks the consistency of a reverse index mapping for a given character set. +template +constexpr bool checkReverseIndex( + uint8_t index, + const Charset& charset, + const ReverseIndex& reverseIndex) { + return (reverseIndex[index] == 255 + ? !findCharacterInCharset(charset, 0, static_cast(index)) + : (charset[reverseIndex[index]] == index)) && + (index > 0 ? checkReverseIndex(index - 1, charset, reverseIndex) : true); +} + +template +uint8_t reverseLookup( + char encodedChar, + const ReverseIndexType& reverseIndex, + Status& status, + uint8_t kBase) { + auto curr = reverseIndex[static_cast(encodedChar)]; + if (curr >= kBase) { + status = + Status::UserError("invalid input string: contains invalid characters."); + return 0; // Return 0 or any other error code indicating failure + } + return curr; +} + +// Returns the actual size of the decoded data. Will also remove the padding +// length from the 'inputSize'. +static Status calculateDecodedSize( + std::string_view input, + size_t& inputSize, + size_t& decodedSize, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + decodedSize = 0; + return Status::OK(); + } + + // Check if the input string is padded + if (isPadded(input)) { + // If padded, ensure that the string length is a multiple of the encoded + // block size + if (inputSize % encodedBlockByteSize != 0) { + return Status::UserError( + "decode() - invalid input string: " + "string length is not a multiple of 4."); + } + + decodedSize = (inputSize * binaryBlockByteSize) / encodedBlockByteSize; + auto paddingCount = numPadding(input); + inputSize -= paddingCount; + + // Adjust the needed size by deducting the bytes corresponding to the + // padding from the calculated size. + decodedSize -= + ((paddingCount * binaryBlockByteSize) + (encodedBlockByteSize - 1)) / + encodedBlockByteSize; + } else { + // If not padded, calculate extra bytes, if any + auto extraBytes = inputSize % encodedBlockByteSize; + decodedSize = (inputSize / encodedBlockByteSize) * binaryBlockByteSize; + // Adjust the needed size for extra bytes, if present + if (extraBytes) { + if (extraBytes == 1) { + return Status::UserError( + "Base64::decode() - invalid input string: " + "string length cannot be 1 more than a multiple of 4."); + } + decodedSize += (extraBytes * binaryBlockByteSize) / encodedBlockByteSize; + } + } + + return Status::OK(); +} + +// Calculates the encoded size based on input size. +static size_t calculateEncodedSize( + size_t inputSize, + bool includePadding, + const int binaryBlockByteSize, + const int encodedBlockByteSize) { + if (inputSize == 0) { + return 0; + } + + // Calculate the output size assuming that we are including padding. + size_t encodedSize = + ((inputSize + binaryBlockByteSize - 1) / binaryBlockByteSize) * + encodedBlockByteSize; + + if (!includePadding) { + // If the padding was not requested, subtract the padding bytes. + size_t remainder = inputSize % binaryBlockByteSize; + if (remainder != 0) { + encodedSize -= (binaryBlockByteSize - remainder); + } + } + + return encodedSize; +} + +} // namespace facebook::velox::encoding diff --git a/velox/common/encode/tests/Base64Test.cpp b/velox/common/encode/tests/Base64Test.cpp index 41f173b7d25c..ed0bbb7e693c 100644 --- a/velox/common/encode/tests/Base64Test.cpp +++ b/velox/common/encode/tests/Base64Test.cpp @@ -55,7 +55,7 @@ TEST_F(Base64Test, calculateDecodedSize) { size_t encoded_size = initialEncodedSize; size_t decoded_size = 0; Status status = - Base64::calculateDecodedSize(encodedString, encoded_size, decoded_size); + calculateDecodedSize(encodedString, encoded_size, decoded_size, 3, 4); if (expectedStatus.ok()) { EXPECT_EQ(Status::OK(), status); @@ -75,7 +75,7 @@ TEST_F(Base64Test, calculateDecodedSize) { 0, 0, Status::UserError( - "Base64::decode() - invalid input string: string length is not a multiple of 4.")); + "decode() - invalid input string: string length is not a multiple of 4.")); checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4=", 32, 31, 23); checkDecodedSize("QmFzZTY0IGVuY29kaW5nIGlzIGZ1bi4", 31, 31, 23); checkDecodedSize("MTIzNDU2Nzg5MA==", 16, 14, 10); diff --git a/velox/common/encode/tests/CMakeLists.txt b/velox/common/encode/tests/CMakeLists.txt index 63f718c24745..663b2413557a 100644 --- a/velox/common/encode/tests/CMakeLists.txt +++ b/velox/common/encode/tests/CMakeLists.txt @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -add_executable(velox_common_encode_test Base64Test.cpp) +add_executable(velox_common_encode_test Base64Test.cpp EncoderUtilsTests.cpp) add_test(velox_common_encode_test velox_common_encode_test) target_link_libraries( velox_common_encode_test diff --git a/velox/common/encode/tests/EncoderUtilsTests.cpp b/velox/common/encode/tests/EncoderUtilsTests.cpp new file mode 100644 index 000000000000..e112f8125349 --- /dev/null +++ b/velox/common/encode/tests/EncoderUtilsTests.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/common/encode/EncoderUtils.h" + +namespace facebook::velox::encoding { +class EncoderUtilsTest : public ::testing::Test {}; + +TEST_F(EncoderUtilsTest, isPadded) { + EXPECT_TRUE(isPadded("ABC=")); + EXPECT_FALSE(isPadded("ABC")); +} + +TEST_F(EncoderUtilsTest, numPadding) { + EXPECT_EQ(0, numPadding("ABC")); + EXPECT_EQ(1, numPadding("ABC=")); + EXPECT_EQ(2, numPadding("AB==")); +} + +} // namespace facebook::velox::encoding