diff --git a/velox/common/encode/Base64.cpp b/velox/common/encode/Base64.cpp index 15c7f0ed8ee6..6645638f8340 100644 --- a/velox/common/encode/Base64.cpp +++ b/velox/common/encode/Base64.cpp @@ -161,10 +161,8 @@ std::string Base64::encodeImpl( const T& input, const Charset& charset, bool includePadding) { - const size_t encodedSize{calculateEncodedSize(input.size(), includePadding)}; std::string encodedResult; - encodedResult.resize(encodedSize); - (void)encodeImpl(input, charset, includePadding, encodedResult.data()); + (void)encodeImpl(input, charset, includePadding, encodedResult); return encodedResult; } @@ -184,42 +182,46 @@ size_t Base64::calculateEncodedSize(size_t inputSize, bool includePadding) { } // static -Status Base64::encode(const char* input, size_t inputSize, char* output) { - return encodeImpl( - folly::StringPiece(input, inputSize), kBase64Charset, true, output); +Status Base64::encode(std::string_view input, std::string& output) { + return encodeImpl(input, kBase64Charset, true, output); } // static Status -Base64::encodeUrl(const char* input, size_t inputSize, char* outputBuffer) { +Base64::encodeUrl(std::string_view input, std::string& output) { return encodeImpl( - folly::StringPiece(input, inputSize), + input, kBase64UrlCharset, true, - outputBuffer); + output); } // static template Status Base64::encodeImpl( const T& input, - const Base64::Charset& charset, + const Charset& charset, bool includePadding, - char* outputBuffer) { + std::string& output) { auto inputSize = input.size(); if (inputSize == 0) { + output.clear(); return Status::OK(); } - auto outputPointer = outputBuffer; + // Calculate the output size and resize the string beforehand + size_t outputSize = calculateEncodedSize(inputSize, includePadding); + output.resize(outputSize); // Resize the output string to the required size + + // Use a pointer to write into the pre-allocated buffer + auto outputPointer = output.data(); auto inputIterator = input.begin(); - // For each group of 3 bytes (24 bits) in the input, split that into - // 4 groups of 6 bits and encode that using the supplied charset lookup + // Encode input in chunks of 3 bytes for (; inputSize > 2; inputSize -= 3) { - uint32_t inputBlock = static_cast(*inputIterator++) << 16; - inputBlock |= static_cast(*inputIterator++) << 8; - inputBlock |= static_cast(*inputIterator++); + uint32_t inputBlock = uint8_t(*inputIterator++) << 16; + inputBlock |= uint8_t(*inputIterator++) << 8; + inputBlock |= uint8_t(*inputIterator++); *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; @@ -227,24 +229,22 @@ Status Base64::encodeImpl( *outputPointer++ = charset[inputBlock & 0x3f]; } + // Handle remaining bytes (1 or 2 bytes) if (inputSize > 0) { - // We have either 1 or 2 input bytes left. Encode this similar to the - // above (assuming 0 for all other bytes). Optionally append the '=' - // character if it is requested. - uint32_t inputBlock = static_cast(*inputIterator++) << 16; + uint32_t inputBlock = uint8_t(*inputIterator++) << 16; *outputPointer++ = charset[(inputBlock >> 18) & 0x3f]; if (inputSize > 1) { - inputBlock |= static_cast(*inputIterator) << 8; + inputBlock |= uint8_t(*inputIterator) << 8; *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; *outputPointer++ = charset[(inputBlock >> 6) & 0x3f]; if (includePadding) { - *outputPointer = kPadding; + *outputPointer++ = kPadding; } } else { *outputPointer++ = charset[(inputBlock >> 12) & 0x3f]; if (includePadding) { *outputPointer++ = kPadding; - *outputPointer = kPadding; + *outputPointer++ = kPadding; } } } @@ -252,9 +252,10 @@ Status Base64::encodeImpl( return Status::OK(); } + // static -std::string Base64::encode(folly::StringPiece text) { - return encodeImpl(text, kBase64Charset, true); +std::string Base64::encode(folly::StringPiece input) { + return encodeImpl(input, kBase64Charset, true); } // static @@ -425,7 +426,7 @@ Status Base64::decodeImpl( // Set up input and output pointers const char* inputPtr = input.data(); - char* outputPtr = output.data(); + char* outputPointer = output.data(); Status lookupStatus; // Process full blocks of 4 characters @@ -440,13 +441,13 @@ Status Base64::decodeImpl( return lookupStatus; } - uint32_t currentBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3; - outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); - outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); - outputPtr[2] = static_cast(currentBlock & 0xFF); + uint32_t inputBlock = (val0 << 18) | (val1 << 12) | (val2 << 6) | val3; + outputPointer[0] = static_cast((inputBlock >> 16) & 0xFF); + outputPointer[1] = static_cast((inputBlock >> 8) & 0xFF); + outputPointer[2] = static_cast(inputBlock & 0xFF); inputPtr += 4; - outputPtr += 3; + outputPointer += 3; } // Handle remaining characters (2 or 3 characters at the end) @@ -454,14 +455,14 @@ Status Base64::decodeImpl( if (remaining > 1) { uint8_t val0 = base64ReverseLookup(inputPtr[0], reverseIndex, lookupStatus); uint8_t val1 = base64ReverseLookup(inputPtr[1], reverseIndex, lookupStatus); - uint32_t currentBlock = (val0 << 18) | (val1 << 12); - outputPtr[0] = static_cast((currentBlock >> 16) & 0xFF); + uint32_t inputBlock = (val0 << 18) | (val1 << 12); + outputPointer[0] = static_cast((inputBlock >> 16) & 0xFF); if (remaining == 3) { uint8_t val2 = base64ReverseLookup(inputPtr[2], reverseIndex, lookupStatus); - currentBlock |= (val2 << 6); - outputPtr[1] = static_cast((currentBlock >> 8) & 0xFF); + inputBlock |= (val2 << 6); + outputPointer[1] = static_cast((inputBlock >> 8) & 0xFF); } } diff --git a/velox/common/encode/Base64.h b/velox/common/encode/Base64.h index 358982f6b42f..b969dc52e4ae 100644 --- a/velox/common/encode/Base64.h +++ b/velox/common/encode/Base64.h @@ -42,16 +42,16 @@ class Base64 { // Encoding Functions /// Encodes the input data using Base64 encoding. static std::string encode(const char* input, size_t inputSize); - static std::string encode(folly::StringPiece text); + static std::string encode(folly::StringPiece input); static std::string encode(const folly::IOBuf* inputBuffer); - static Status encode(const char* input, size_t inputSize, char* outputBuffer); + static Status encode(std::string_view input, std::string& outputBuffer); /// Encodes the input data using Base64 URL encoding. static std::string encodeUrl(const char* input, size_t inputSize); static std::string encodeUrl(folly::StringPiece text); static std::string encodeUrl(const folly::IOBuf* inputBuffer); static Status - encodeUrl(const char* input, size_t inputSize, char* outputBuffer); + encodeUrl(std::string_view input, std::string& output); // Decoding Functions /// Decodes the input Base64 encoded string. @@ -69,10 +69,6 @@ class Base64 { std::string& output); static Status decodeUrl(std::string_view input, std::string& output); - // Helper Functions - /// Calculates the encoded size based on input size. - static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); - private: // Checks if the input Base64 string is padded. static inline bool isPadded(std::string_view input) { @@ -107,7 +103,7 @@ class Base64 { const T& input, const Charset& charset, bool includePadding, - char* outputBuffer); + std::string& output); static Status decodeImpl( std::string_view input, @@ -121,6 +117,9 @@ class Base64 { size_t& inputSize, size_t& decodedSize); + // Calculates the encoded size based on input size. + static size_t calculateEncodedSize(size_t inputSize, bool withPadding = true); + VELOX_FRIEND_TEST(Base64Test, isPadded); VELOX_FRIEND_TEST(Base64Test, numPadding); }; diff --git a/velox/functions/prestosql/BinaryFunctions.h b/velox/functions/prestosql/BinaryFunctions.h index d3673add45b0..050ab7a5d42f 100644 --- a/velox/functions/prestosql/BinaryFunctions.h +++ b/velox/functions/prestosql/BinaryFunctions.h @@ -280,8 +280,15 @@ struct ToBase64Function { FOLLY_ALWAYS_INLINE Status call(out_type& result, const arg_type& input) { - result.resize(encoding::Base64::calculateEncodedSize(input.size())); - return encoding::Base64::encode(input.data(), input.size(), result.data()); + std::string_view inputView(input.data(), input.size()); + std::string output; + auto status = encoding::Base64::encode(inputView, output); + if (!status.ok()) { + return status; + } + result.resize(output.size()); + std::memcpy(result.data(), output.data(), output.size()); + return Status::OK(); } }; @@ -328,9 +335,15 @@ struct ToBase64UrlFunction { FOLLY_ALWAYS_INLINE Status call(out_type& result, const arg_type& input) { - result.resize(encoding::Base64::calculateEncodedSize(input.size())); - return encoding::Base64::encodeUrl( - input.data(), input.size(), result.data()); + std::string_view inputView(input.data(), input.size()); + std::string output; + auto status = encoding::Base64::encodeUrl(inputView, output); + if (!status.ok()) { + return status; + } + result.resize(output.size()); + std::memcpy(result.data(), output.data(), output.size()); + return Status::OK(); } };