Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[improve](function) opt aes_encrypt/decrypt function to handle const column #37194

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 205 additions & 60 deletions be/src/vec/functions/function_encryption.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "vec/columns/columns_number.h"
#include "vec/common/assert_cast.h"
#include "vec/common/pod_array.h"
#include "vec/common/string_ref.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
#include "vec/core/column_with_type_and_name.h"
Expand Down Expand Up @@ -110,54 +111,44 @@ class FunctionEncryptionAndDecrypt : public IFunction {

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
size_t argument_size = arguments.size();
std::vector<ColumnPtr> argument_columns(argument_size);
std::vector<const ColumnString::Offsets*> offsets_list(argument_size);
std::vector<const ColumnString::Chars*> chars_list(argument_size);

auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
auto result_data_column = ColumnString::create();

auto& result_data = result_data_column->get_chars();
auto& result_offset = result_data_column->get_offsets();
result_offset.resize(input_rows_count);

for (int i = 0; i < argument_size; ++i) {
argument_columns[i] =
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
if (auto* nullable = check_and_get_column<ColumnNullable>(*argument_columns[i])) {
VectorizedUtils::update_null_map(result_null_map->get_data(),
nullable->get_null_map_data());
argument_columns[i] = nullable->get_nested_column_ptr();
}
}

for (size_t i = 0; i < argument_size; ++i) {
auto col_str = assert_cast<const ColumnString*>(argument_columns[i].get());
offsets_list[i] = &col_str->get_offsets();
chars_list[i] = &col_str->get_chars();
}

RETURN_IF_ERROR(Impl::vector_vector(offsets_list, chars_list, input_rows_count, result_data,
result_offset, result_null_map->get_data()));
block.get_by_position(result).column =
ColumnNullable::create(std::move(result_data_column), std::move(result_null_map));
return Status::OK();
return Impl::execute_impl_inner(context, block, arguments, result, input_rows_count);
}
};

template <typename Impl, bool is_encrypt>
void exectue_result(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
int src_size = (*offsets_list[0])[i] - (*offsets_list[0])[i - 1];
const auto src_raw =
const auto* src_raw =
reinterpret_cast<const char*>(&(*chars_list[0])[(*offsets_list[0])[i - 1]]);
int key_size = (*offsets_list[1])[i] - (*offsets_list[1])[i - 1];
const auto key_raw =
const auto* key_raw =
reinterpret_cast<const char*>(&(*chars_list[1])[(*offsets_list[1])[i - 1]]);
execute_result<Impl, is_encrypt>(src_raw, src_size, key_raw, key_size, i, encryption_mode,
iv_raw, iv_length, result_data, result_offset, null_map);
}

// Processes row `i` of a string column with a key that is the same for every row.
//
// The source bytes for the row are located through `offsets_column` / `chars_column`
// and forwarded, together with `key_arg`, the mode and the IV, to
// execute_result<Impl, is_encrypt>, which writes into `result_data`, `result_offset`
// and `null_map`.
//
// NOTE(review): the row start is read from offsets[i - 1], including for i == 0;
// presumably the offsets container tolerates that index — confirm.
template <typename Impl, bool is_encrypt>
void execute_result_const(const ColumnString::Offsets* offsets_column,
                          const ColumnString::Chars* chars_column, StringRef key_arg, size_t i,
                          EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
                          ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
                          NullMap& null_map) {
    const auto& offsets = *offsets_column;
    const auto& chars = *chars_column;

    // The row occupies [offsets[i - 1], offsets[i]) inside the flat character buffer.
    int src_size = offsets[i] - offsets[i - 1];
    const auto* src_raw = reinterpret_cast<const char*>(&chars[offsets[i - 1]]);

    execute_result<Impl, is_encrypt>(src_raw, src_size, key_arg.data, key_arg.size, i,
                                     encryption_mode, iv_raw, iv_length, result_data,
                                     result_offset, null_map);
}

template <typename Impl, bool is_encrypt>
void execute_result(const char* src_raw, int src_size, const char* key_raw, int key_size, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
if (src_size == 0) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
return;
Expand All @@ -170,7 +161,7 @@ void exectue_result(std::vector<const ColumnString::Offsets*>& offsets_list,
p.reset(new char[cipher_len]);
int ret_code = 0;

ret_code = Impl::exectue_impl(encryption_mode, (unsigned char*)src_raw, src_size,
ret_code = Impl::execute_impl(encryption_mode, (unsigned char*)src_raw, src_size,
(unsigned char*)key_raw, key_size, iv_raw, iv_length, true,
(unsigned char*)p.get());

Expand All @@ -189,18 +180,90 @@ struct EncryptionAndDecryptTwoImpl {
std::make_shared<DataTypeString>()};
}

static Status vector_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list,
size_t input_rows_count, ColumnString::Chars& result_data,
ColumnString::Offsets& result_offset, NullMap& null_map) {
// Evaluates the three-argument form (arguments 0, 1, 2) over a whole block.
//
// Records which arguments are constant columns, materializes argument 0 when it is
// constant, and hands arguments 1 and 2 to default_preprocess_parameter_columns.
// When both of those arguments are constant, rows are processed by vector_const with
// their values read once from row 0; otherwise vector_vector reads every argument
// per row. The column at position `result` is replaced by a ColumnNullable wrapping
// the produced string column and null map.
//
// `context` is not used here. Always returns Status::OK().
static Status execute_impl_inner(FunctionContext* context, Block& block,
                                 const ColumnNumbers& arguments, size_t result,
                                 size_t input_rows_count) {
    auto result_column = ColumnString::create();
    auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);
    DCHECK_EQ(3, arguments.size());

    constexpr size_t kArgCount = 3;
    bool is_const[kArgCount];
    ColumnPtr columns[kArgCount];
    for (size_t idx = 0; idx < kArgCount; ++idx) {
        is_const[idx] = is_column_const(*block.get_by_position(arguments[idx]).column);
    }

    // Argument 0 is always handled as a full column, so expand it if it is constant.
    if (is_const[0]) {
        columns[0] = static_cast<const ColumnConst&>(*block.get_by_position(arguments[0]).column)
                             .convert_to_full_column();
    } else {
        columns[0] = block.get_by_position(arguments[0]).column;
    }

    default_preprocess_parameter_columns(columns, is_const, {1, 2}, block, arguments);

    // NOTE(review): check_set_nullable presumably folds each argument's null map into
    // the result null map and unwraps nullable columns — confirm against its definition.
    for (size_t idx = 0; idx < kArgCount; ++idx) {
        check_set_nullable(columns[idx], result_null_map_column, is_const[idx]);
    }

    auto& result_data = result_column->get_chars();
    auto& result_offset = result_column->get_offsets();
    result_offset.resize(input_rows_count);

    const bool params_are_const = is_const[1] && is_const[2];
    if (params_are_const) {
        // Fast path: arguments 1 and 2 are the same for every row, read them from row 0.
        vector_const(assert_cast<const ColumnString*>(columns[0].get()),
                     columns[1]->get_data_at(0), columns[2]->get_data_at(0), input_rows_count,
                     result_data, result_offset, result_null_map_column->get_data());
    } else {
        // General path: expose offsets/chars of every argument for per-row access.
        std::vector<const ColumnString::Offsets*> offsets_list(kArgCount);
        std::vector<const ColumnString::Chars*> chars_list(kArgCount);
        for (size_t idx = 0; idx < kArgCount; ++idx) {
            const auto* string_column = assert_cast<const ColumnString*>(columns[idx].get());
            offsets_list[idx] = &string_column->get_offsets();
            chars_list[idx] = &string_column->get_chars();
        }
        vector_vector(offsets_list, chars_list, input_rows_count, result_data, result_offset,
                      result_null_map_column->get_data());
    }

    block.get_by_position(result).column =
            ColumnNullable::create(std::move(result_column), std::move(result_null_map_column));
    return Status::OK();
}

// Processes every row of `column` with a constant key and a constant mode string.
//
// column           source strings, one per row.
// key_arg          key shared by all rows.
// mode_arg         mode name; when empty the struct-level default `mode` is used.
// input_rows_count number of rows to produce.
// result_data / result_offset / null_map  output buffers, filled row by row.
//
// If `mode_arg` is non-empty but is not a key of `aes_mode_map`, every row is emitted
// as NULL. Rows already flagged in `null_map` are emitted as NULL as well.
static void vector_const(const ColumnString* column, StringRef key_arg, StringRef mode_arg,
                         size_t input_rows_count, ColumnString::Chars& result_data,
                         ColumnString::Offsets& result_offset, NullMap& null_map) {
    EncryptionMode encryption_mode = mode;
    bool all_insert_null = false;
    if (mode_arg.size != 0) {
        const std::string mode_str(mode_arg.data, mode_arg.size);
        // Look the mode up once. Calling at() on a missing key would throw instead of
        // letting the rows fall through to the NULL path, so only read the value on a hit.
        if (const auto it = aes_mode_map.find(mode_str); it != aes_mode_map.end()) {
            encryption_mode = it->second;
        } else {
            all_insert_null = true;
        }
    }

    const ColumnString::Offsets* offsets_column = &column->get_offsets();
    const ColumnString::Chars* chars_column = &column->get_chars();
    for (size_t i = 0; i < input_rows_count; ++i) {
        if (all_insert_null || null_map[i]) {
            StringOP::push_null_string(i, result_data, result_offset, null_map);
            continue;
        }
        // This overload has no IV argument, so none is passed on.
        execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
                                               encryption_mode, nullptr, 0, result_data,
                                               result_offset, null_map);
    }
}

static void vector_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list,
size_t input_rows_count, ColumnString::Chars& result_data,
ColumnString::Offsets& result_offset, NullMap& null_map) {
for (int i = 0; i < input_rows_count; ++i) {
if (null_map[i]) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
continue;
}
EncryptionMode encryption_mode = mode;
int mode_size = (*offsets_list[2])[i] - (*offsets_list[2])[i - 1];
const auto mode_raw =
const auto* mode_raw =
reinterpret_cast<const char*>(&(*chars_list[2])[(*offsets_list[2])[i - 1]]);
if (mode_size != 0) {
std::string mode_str(mode_raw, mode_size);
Expand All @@ -210,10 +273,10 @@ struct EncryptionAndDecryptTwoImpl {
}
encryption_mode = aes_mode_map.at(mode_str);
}
exectue_result<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode, nullptr,
0, result_data, result_offset, null_map);
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
nullptr, 0, result_data, result_offset,
null_map);
}
return Status::OK();
}
};

Expand All @@ -224,10 +287,92 @@ struct EncryptionAndDecryptFourImpl {
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
}

static Status vector_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list,
size_t input_rows_count, ColumnString::Chars& result_data,
ColumnString::Offsets& result_offset, NullMap& null_map) {
// Evaluates the four-argument form (arguments 0..3) over a whole block.
//
// Records which arguments are constant columns, materializes argument 0 when it is
// constant, and hands arguments 1, 2 and 3 to default_preprocess_parameter_columns.
// When all three of those arguments are constant, rows are processed by vector_const
// with their values read once from row 0; otherwise vector_vector reads every
// argument per row. The column at position `result` is replaced by a ColumnNullable
// wrapping the produced string column and null map.
//
// `context` is not used here. Always returns Status::OK().
static Status execute_impl_inner(FunctionContext* context, Block& block,
                                 const ColumnNumbers& arguments, size_t result,
                                 size_t input_rows_count) {
    auto result_column = ColumnString::create();
    auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);
    DCHECK_EQ(4, arguments.size());
    const size_t argument_size = 4;
    bool col_const[argument_size];
    ColumnPtr argument_columns[argument_size];
    for (size_t i = 0; i < argument_size; ++i) {
        col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column);
    }
    // Argument 0 is always handled as a full column, so expand it if it is constant.
    argument_columns[0] = col_const[0] ? static_cast<const ColumnConst&>(
                                                 *block.get_by_position(arguments[0]).column)
                                                 .convert_to_full_column()
                                       : block.get_by_position(arguments[0]).column;

    default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block,
                                         arguments);

    // NOTE(review): check_set_nullable presumably folds each argument's null map into
    // the result null map and unwraps nullable columns — confirm against its definition.
    for (size_t i = 0; i < argument_size; i++) {
        check_set_nullable(argument_columns[i], result_null_map_column, col_const[i]);
    }
    auto& result_data = result_column->get_chars();
    auto& result_offset = result_column->get_offsets();
    result_offset.resize(input_rows_count);

    if (col_const[1] && col_const[2] && col_const[3]) {
        // vector_const declares its StringRef parameters as (iv_arg, key_arg, mode_arg).
        // The per-row path reads the key from argument 1 and the IV from argument 2, so
        // the IV (argument 2) must be passed first and the key (argument 1) second;
        // passing them positionally as 1, 2 would swap key and IV on this path only.
        const StringRef key_arg = argument_columns[1]->get_data_at(0);
        const StringRef iv_arg = argument_columns[2]->get_data_at(0);
        const StringRef mode_arg = argument_columns[3]->get_data_at(0);
        vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()), iv_arg,
                     key_arg, mode_arg, input_rows_count, result_data, result_offset,
                     result_null_map_column->get_data());
    } else {
        // General path: expose offsets/chars of every argument for per-row access.
        std::vector<const ColumnString::Offsets*> offsets_list(argument_size);
        std::vector<const ColumnString::Chars*> chars_list(argument_size);
        for (size_t i = 0; i < argument_size; ++i) {
            const auto* col_str = assert_cast<const ColumnString*>(argument_columns[i].get());
            offsets_list[i] = &col_str->get_offsets();
            chars_list[i] = &col_str->get_chars();
        }
        vector_vector(offsets_list, chars_list, input_rows_count, result_data, result_offset,
                      result_null_map_column->get_data());
    }
    block.get_by_position(result).column =
            ColumnNullable::create(std::move(result_column), std::move(result_null_map_column));
    return Status::OK();
}

// Processes every row of `column` with a constant IV, key and mode string.
//
// column           source strings, one per row.
// iv_arg           initialization vector shared by all rows (first StringRef parameter).
// key_arg          key shared by all rows (second StringRef parameter).
// mode_arg         mode name; when empty the struct-level default `mode` is used.
// input_rows_count number of rows to produce.
// result_data / result_offset / null_map  output buffers, filled row by row.
//
// The mode name is looked up in `sm4_mode_map` when `is_sm_mode` is set and in
// `aes_mode_map` otherwise. If it is non-empty but unknown, every row is emitted as
// NULL. Rows already flagged in `null_map` are emitted as NULL as well.
//
// NOTE(review): both iv_arg and key_arg are StringRef, so a caller that passes them in
// (key, iv) order compiles silently — verify call sites against this parameter order.
static void vector_const(const ColumnString* column, StringRef iv_arg, StringRef key_arg,
                         StringRef mode_arg, size_t input_rows_count,
                         ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
                         NullMap& null_map) {
    EncryptionMode encryption_mode = mode;
    bool all_insert_null = false;
    if (mode_arg.size != 0) {
        const std::string mode_str(mode_arg.data, mode_arg.size);
        // Look the mode up once. Calling at() on a missing key would throw instead of
        // letting the rows fall through to the NULL path, so only read the value on a hit.
        if constexpr (is_sm_mode) {
            if (const auto it = sm4_mode_map.find(mode_str); it != sm4_mode_map.end()) {
                encryption_mode = it->second;
            } else {
                all_insert_null = true;
            }
        } else {
            if (const auto it = aes_mode_map.find(mode_str); it != aes_mode_map.end()) {
                encryption_mode = it->second;
            } else {
                all_insert_null = true;
            }
        }
    }

    const ColumnString::Offsets* offsets_column = &column->get_offsets();
    const ColumnString::Chars* chars_column = &column->get_chars();
    for (size_t i = 0; i < input_rows_count; ++i) {
        if (all_insert_null || null_map[i]) {
            StringOP::push_null_string(i, result_data, result_offset, null_map);
            continue;
        }
        execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
                                               encryption_mode, iv_arg.data, iv_arg.size,
                                               result_data, result_offset, null_map);
    }
}

static void vector_vector(std::vector<const ColumnString::Offsets*>& offsets_list,
std::vector<const ColumnString::Chars*>& chars_list,
size_t input_rows_count, ColumnString::Chars& result_data,
ColumnString::Offsets& result_offset, NullMap& null_map) {
for (int i = 0; i < input_rows_count; ++i) {
if (null_map[i]) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
Expand All @@ -237,9 +382,9 @@ struct EncryptionAndDecryptFourImpl {
EncryptionMode encryption_mode = mode;
int mode_size = (*offsets_list[3])[i] - (*offsets_list[3])[i - 1];
int iv_size = (*offsets_list[2])[i] - (*offsets_list[2])[i - 1];
const auto mode_raw =
const auto* mode_raw =
reinterpret_cast<const char*>(&(*chars_list[3])[(*offsets_list[3])[i - 1]]);
const auto iv_raw =
const auto* iv_raw =
reinterpret_cast<const char*>(&(*chars_list[2])[(*offsets_list[2])[i - 1]]);
if (mode_size != 0) {
std::string mode_str(mode_raw, mode_size);
Expand All @@ -258,15 +403,15 @@ struct EncryptionAndDecryptFourImpl {
}
}

exectue_result<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode, iv_raw,
iv_size, result_data, result_offset, null_map);
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
iv_raw, iv_size, result_data, result_offset,
null_map);
}
return Status::OK();
}
};

struct EncryptImpl {
static int exectue_impl(EncryptionMode mode, const unsigned char* source,
static int execute_impl(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
return EncryptionUtil::encrypt(mode, source, source_length, key, key_length, iv, iv_length,
Expand All @@ -275,7 +420,7 @@ struct EncryptImpl {
};

struct DecryptImpl {
static int exectue_impl(EncryptionMode mode, const unsigned char* source,
static int execute_impl(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
return EncryptionUtil::decrypt(mode, source, source_length, key, key_length, iv, iv_length,
Expand Down
Loading