From 63d31aa650ac45cc0ab1aa613a09d7fd432bb27a Mon Sep 17 00:00:00 2001
From: Ethan Steinberg
Date: Sun, 7 Apr 2024 21:22:07 -0700
Subject: [PATCH] Switch row compression from snappy to zstd and rework the
 sharding pipeline

---
 native/BUILD          |  17 +-
 native/WORKSPACE      |  14 +-
 native/main.cc        |   2 +-
 native/perform_etl.cc | 797 ++++++++++++++++++++++++++++--------------
 setup.py              |   3 -
 5 files changed, 550 insertions(+), 283 deletions(-)

diff --git a/native/BUILD b/native/BUILD
index f689985..d406806 100644
--- a/native/BUILD
+++ b/native/BUILD
@@ -18,7 +18,7 @@ cc_library(
         "@com_google_absl//absl/types:optional",
         "@readerwriterqueue",
         "@concurrentqueue",
-        "@snappy",
+        ":zstd",
     ],
 )

@@ -55,8 +55,6 @@ cmake(
         "EP_COMMON_CMAKE_ARGS": "-DWITH_OPENSSL=OFF",
         "ARROW_DEPENDENCY_SOURCE": "BUNDLED",
     },
-    deps = [
-    ],
     tags = ["requires-network"],
     generate_args = ["-DCMAKE_RANLIB=/usr/bin/ranlib"],
     working_directory="cpp",
@@ -64,5 +62,16 @@ cmake(
     out_lib_dir = "lib64",
     out_static_libs = ["libparquet.a", "libarrow.a", "libarrow_bundled_dependencies.a"],
     linkopts = ["-pthread"],
-    )
+)
+
+cmake(
+    name = "zstd",
+    cache_entries = {
+        "CMAKE_C_FLAGS": "-fPIC",
+        "CMAKE_CXX_FLAGS": "-fPIC",
+    },
+    working_directory="build/cmake",
+    lib_source = "@zstd//:all",
+    out_lib_dir = "lib64",
+    out_static_libs = ["libzstd.a"],
+)
\ No newline at end of file
diff --git a/native/WORKSPACE b/native/WORKSPACE
index a7edfa4..8599c33 100644
--- a/native/WORKSPACE
+++ b/native/WORKSPACE
@@ -8,13 +8,6 @@ http_archive(
     sha256 = "efa465b26da194f82320b4e39e2ca637ebe3129d7f3732ee71d9942099e3c773",
 )

-http_archive(
-    name = "snappy",
-    urls = ["https://github.com/google/snappy/archive/27f34a580be4a3becf5f8c0cba13433f53c21337.zip"],
-    strip_prefix = "snappy-27f34a580be4a3becf5f8c0cba13433f53c21337",
-    sha256 = "6c18c64264a2b15531896824b85c47e85da3e42059631661fbe0c588e4e3aa99",
-)
-
 new_git_repository(
     name = "concurrentqueue",
     remote = "https://github.com/cameron314/concurrentqueue.git",
@@ -90,3 +83,10 @@ http_archive(
     sha256 = "c0ebc13e9bc428c0c9f617444dcd161140e2bf9a7c65c5ae6bf239e5567175fd",
     patches = ["//patches:ThirdpartyToolchain.cmake.patch"],
 )
+
+http_archive(
+    name="zstd",
+    strip_prefix="zstd-1.5.6",
+    urls = ["https://github.com/facebook/zstd/archive/refs/tags/v1.5.6.zip"],
+    build_file_content = all_content,
+)
\ No newline at end of file
diff --git a/native/main.cc b/native/main.cc
index d3f2c0f..4c81b3c 100644
--- a/native/main.cc
+++ b/native/main.cc
@@ -6,7 +6,7 @@ int main() {
     std::string path_to_folder =
         "/labs/shahlab/projects/ethanid/mimic_test/mimic_demo/temp";
     std::string output =
-        "/labs/shahlab/projects/ethanid/optimize_etl/meds_etl_native/native/"
+        "/labs/shahlab/projects/ethanid/optimize_etl/meds_etl_cpp/native/"
         "output";

     perform_etl(path_to_folder, output, num_shards);
diff --git a/native/perform_etl.cc b/native/perform_etl.cc
index 7d76510..f007613 100644
--- a/native/perform_etl.cc
+++ b/native/perform_etl.cc
@@ -21,16 +21,16 @@
 #include "arrow/table.h"
 #include "arrow/util/type_fwd.h"
 #include "blockingconcurrentqueue.h"
+#include "lightweightsemaphore.h"
 #include "parquet/arrow/reader.h"
 #include "parquet/arrow/schema.h"
 #include "parquet/arrow/writer.h"
-#include "readerwritercircularbuffer.h"
-#include "snappy.h"
+#include "zstd.h"

 namespace fs = std::filesystem;

-const size_t SHARD_PIECE_SIZE = 1000 * 1000 * 1000;  // Roughly 1 gigabyte
-const size_t SNAPPY_BUFFER_SIZE = 4 * 1000 * 1000;   // Roughly 4 megabytes
+const size_t SHARD_PIECE_SIZE = 2 * 1000 * 1000 * 1000;  // Roughly 2 gigabytes
+const size_t COMPRESSION_BUFFER_SIZE = 4 * 1000 * 1000;  // Roughly 4 megabytes
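[Review note — not part of the patch] The new writer emits a stream of blocks, each prefixed with its compressed length, and the reader recovers the uncompressed size from the zstd frame header. A minimal, standalone sketch of the one-shot zstd calls this relies on (buffer names are illustrative, not from the patch):

    #include <string>
    #include <vector>
    #include <zstd.h>

    int main() {
        std::string input = "example row bytes";

        // Compress into a worst-case-sized buffer.
        std::vector<char> compressed(ZSTD_compressBound(input.size()));
        size_t clen = ZSTD_compress(compressed.data(), compressed.size(),
                                    input.data(), input.size(), /*level=*/1);
        if (ZSTD_isError(clen)) return 1;

        // Single-shot frames record their content size in the frame header,
        // which is what the ZstdRowReader below depends on.
        unsigned long long usize =
            ZSTD_getFrameContentSize(compressed.data(), clen);
        if (usize == ZSTD_CONTENTSIZE_ERROR ||
            usize == ZSTD_CONTENTSIZE_UNKNOWN) return 1;

        std::vector<char> restored(usize);
        size_t dlen = ZSTD_decompress(restored.data(), restored.size(),
                                      compressed.data(), clen);
        return (ZSTD_isError(dlen) || dlen != input.size()) ? 1 : 0;
    }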

 std::vector<std::shared_ptr<arrow::Field>> get_fields_for_file(
     arrow::MemoryPool* pool, const std::string& filename) {
@@ -74,14 +74,10 @@ const std::vector<std::string> known_fields = {
     "patient_id",    "time",           "code",
     "numeric_value", "datetime_value", "text_value"};

-bool is_string_type(const arrow::Field& field) {
-    return field.type()->Equals(arrow::LargeStringType());
-}
-
-std::set<std::string> get_metadata_fields(
-    const std::vector<std::string>& files) {
+std::set<std::pair<std::string, std::shared_ptr<arrow::DataType>>>
+get_metadata_fields(const std::vector<std::string>& files) {
     arrow::MemoryPool* pool = arrow::default_memory_pool();
-    std::set<std::string> result;
+    std::set<std::pair<std::string, std::shared_ptr<arrow::DataType>>> result;

     for (const auto& file : files) {
         auto fields = get_fields_for_file(pool, file);
@@ -95,13 +91,7 @@ std::set<std::string> get_metadata_fields(
             if (std::find(std::begin(known_fields), std::end(known_fields),
                           field->name()) == std::end(known_fields)) {
-                if (!is_string_type(*field)) {
-                    throw std::runtime_error(
-                        "The C++ MEDS-Flat ETL only supports large_string "
-                        "metadata for now, but found " +
-                        field->ToString());
-                }
-                result.insert(field->name());
+                result.insert(std::make_pair(field->name(), field->type()));
             }
         }
     }

     return result;
 }
@@ -109,10 +99,13 @@ std::set<std::string> get_metadata_fields_multithreaded(
-std::set<std::string> get_metadata_fields_multithreaded(
-    const std::vector<std::string>& files, size_t num_threads) {
+std::set<std::pair<std::string, std::shared_ptr<arrow::DataType>>>
+get_metadata_fields_multithreaded(const std::vector<std::string>& files,
+                                  size_t num_threads) {
     std::vector<std::thread> threads;
-    std::vector<std::set<std::string>> results(num_threads);
+    std::vector<
+        std::set<std::pair<std::string, std::shared_ptr<arrow::DataType>>>>
+        results(num_threads);

     size_t files_per_thread = (files.size() + num_threads - 1) / num_threads;
@@ -131,7 +124,7 @@ std::set<std::string> get_metadata_fields_multithreaded(
         thread.join();
     }

-    std::set<std::string> result;
+    std::set<std::pair<std::string, std::shared_ptr<arrow::DataType>>> result;

     for (auto& res : results) {
         result.merge(std::move(res));
@@ -154,7 +147,6 @@ microsecond precision
     absl::optional<std::string_view> text_value;

     std::vector<absl::optional<std::string_view>> metadata_columns;
 };
-*/

 struct Row {
     int64_t patient_id;
     int64_t time;
     std::vector<char> data;

     bool operator<(const Row& rhs) const {
         return std::make_pair(patient_id, time) <
                std::make_pair(rhs.patient_id, rhs.time);
     }
 };
+*/

 template <typename T>
 void add_literal_to_vector(std::vector<char>& data, T to_add) {
@@ -179,19 +172,28 @@ void add_string_to_vector(std::vector<char>& data, std::string_view to_add) {
     data.insert(std::end(data), std::begin(to_add), std::end(to_add));
 }

-using QueueItem = absl::optional<Row>;
-using QueueType = moodycamel::BlockingReaderWriterCircularBuffer<QueueItem>;
+using QueueItem = absl::optional<std::vector<char>>;
+
+constexpr ssize_t SEMAPHORE_BLOCK_SIZE = 1000;

-void sort_reader(
+void shard_reader(
     size_t reader_index, size_t num_shards,
     moodycamel::BlockingConcurrentQueue<absl::optional<std::string>>&
         file_queue,
-    std::vector<
-        std::vector<QueueType>>&
+    std::vector<moodycamel::BlockingConcurrentQueue<QueueItem>>&
         all_write_queues,
+    moodycamel::LightweightSemaphore& all_write_semaphore,
-    const std::vector<std::string>& metadata_columns) {
+    const std::vector<std::pair<std::string, std::shared_ptr<arrow::DataType>>>&
+        metadata_columns) {
     arrow::MemoryPool* pool = arrow::default_memory_pool();

+    std::vector<moodycamel::ProducerToken> ptoks;
+    for (size_t i = 0; i < num_shards; i++) {
+        ptoks.emplace_back(all_write_queues[i]);
+    }
+
+    ssize_t slots_to_write = all_write_semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
+
     absl::optional<std::string> item;
     while (true) {
         file_queue.wait_dequeue(item);
@@ -229,6 +231,9 @@ void sort_reader(

         std::vector<int> metadata_indices(metadata_columns.size(), -1);

+        std::bitset<std::numeric_limits<uint64_t>::digits>
+            is_text_metadata;
+
         const auto& manifest = arrow_reader->manifest();
         for (const auto& schema_field : manifest.schema_fields) {
             if (schema_field.children.size() != 0 ||
@@ -258,7 +263,8 @@ void sort_reader(
                 }
                 time_index = schema_field.column_index;
             } else if (schema_field.field->name() == "code") {
-                if (!is_string_type(*(schema_field.field))) {
+                if (!schema_field.field->type()->Equals(
+                        arrow::LargeStringType())) {
                     throw std::runtime_error(
                         "The C++ MEDS-Flat ETL requires large_string codes "
                         "but found " +
                         schema_field.field->ToString());
                 }
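[Review note — not part of the patch] The reader now budgets queue slots against a shared LightweightSemaphore in blocks of SEMAPHORE_BLOCK_SIZE, and the writers hand slots back as they drain; this bounds total in-flight rows without a bounded queue. A minimal sketch of that backpressure pattern (single producer/consumer; the 1024 limit and int sentinel are illustrative):

    #include <thread>
    #include "blockingconcurrentqueue.h"
    #include "lightweightsemaphore.h"

    int main() {
        moodycamel::BlockingConcurrentQueue<int> queue;
        // Allow at most 1024 items in flight at once.
        moodycamel::LightweightSemaphore slots(1024);

        std::thread producer([&]() {
            moodycamel::ProducerToken ptok(queue);
            for (int i = 0; i < 100000; i++) {
                slots.wait();             // acquire a slot before enqueueing
                queue.enqueue(ptok, i);
            }
            queue.enqueue(ptok, -1);      // sentinel, mirrors absl::nullopt
        });

        std::thread consumer([&]() {
            moodycamel::ConsumerToken ctok(queue);
            int item;
            while (true) {
                queue.wait_dequeue(ctok, item);
                if (item == -1) break;
                slots.signal();           // hand the slot back
            }
        });

        producer.join();
        consumer.join();
    }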
@@ -285,7 +291,8 @@ void sort_reader(
                 }
                 datetime_value_index = schema_field.column_index;
             } else if (schema_field.field->name() == "text_value") {
-                if (!is_string_type(*(schema_field.field))) {
+                if (!schema_field.field->type()->Equals(
+                        arrow::LargeStringType())) {
                     throw std::runtime_error(
-                        "C++ MEDS-Flat requires Float32 numeric_value but "
+                        "C++ MEDS-Flat requires large_string text_value but "
                         "found " +
@@ -294,16 +301,40 @@ void sort_reader(
                 text_value_index = schema_field.column_index;
             } else {
                 // Must be metadata
-                auto iter = std::find(std::begin(metadata_columns),
-                                      std::end(metadata_columns),
-                                      schema_field.field->name());
-                if (!is_string_type(*(schema_field.field))) {
+                auto iter = std::find_if(
+                    std::begin(metadata_columns),
+                    std::end(metadata_columns), [&](const auto& entry) {
+                        return entry.first == schema_field.field->name();
+                    });
+                if (iter == std::end(metadata_columns)) {
+                    throw std::runtime_error(
+                        "Had an extra column in the metadata that "
+                        "shouldn't exist? " +
+                        schema_field.field->ToString());
+                }
+
+                if (!schema_field.field->type()->Equals(iter->second)) {
                     throw std::runtime_error(
-                        "C++ MEDS-Flat requires large_string metadata but "
-                        "found " +
+                        "C++ MEDS-Flat found a metadata column whose type "
+                        "does not match the expected type: " +
                         schema_field.field->ToString());
                 }

+                int offset = (iter - std::begin(metadata_columns));
+
+                if (iter->second->Equals(arrow::LargeStringType())) {
+                    is_text_metadata[offset] = true;
+                } else {
+                    is_text_metadata[offset] = false;
+
+                    if (iter->second->byte_width() == -1) {
+                        throw std::runtime_error(
+                            "Found non text metadata with unknown byte "
+                            "width? " +
+                            iter->second->ToString());
+                    }
+                }
+
                 metadata_indices[offset] = schema_field.column_index;
             }
         }
@@ -379,31 +410,48 @@ void sort_reader(
             }

             std::vector<std::shared_ptr<arrow::LargeStringArray>>
-                metadata_arrays(metadata_columns.size());
+                text_metadata_arrays(metadata_columns.size());
+            std::vector<std::shared_ptr<arrow::FixedSizeBinaryArray>>
+                primitive_metadata_arrays(metadata_columns.size());

             for (size_t i = 0; i < metadata_columns.size(); i++) {
                 if (metadata_indices[i] == -1) {
                     continue;
                 }
-                auto metadata_array =
-                    std::dynamic_pointer_cast<arrow::LargeStringArray>(
-                        record_batch->column(metadata_indices[i]));
-                if (!metadata_array) {
-                    throw std::runtime_error(
-                        "Could not cast metadata array " +
-                        metadata_columns[i]);
+                if (is_text_metadata[i]) {
+                    auto metadata_array =
+                        std::dynamic_pointer_cast<arrow::LargeStringArray>(
+                            record_batch->column(metadata_indices[i]));
+                    if (!metadata_array) {
+                        throw std::runtime_error(
+                            "Could not cast metadata array to text " +
+                            metadata_columns[i].first + " " +
+                            metadata_columns[i].second->ToString());
+                    }
+                    text_metadata_arrays[i] = metadata_array;
+                } else {
+                    std::shared_ptr<arrow::Array> fixed_size_array;
+                    PARQUET_ASSIGN_OR_THROW(
+                        fixed_size_array,
+                        record_batch->column(metadata_indices[i])
+                            ->View(std::make_shared<
+                                   arrow::FixedSizeBinaryType>(
+                                metadata_columns[i].second->byte_width())));
+
+                    auto metadata_array = std::dynamic_pointer_cast<
+                        arrow::FixedSizeBinaryArray>(fixed_size_array);
+                    if (!metadata_array) {
+                        throw std::runtime_error(
+                            "Could not cast metadata array to fixed size " +
+                            metadata_columns[i].first + " " +
+                            metadata_columns[i].second->ToString());
+                    }
+                    primitive_metadata_arrays[i] = metadata_array;
                 }
-                metadata_arrays[i] = metadata_array;
             }

-            std::cout << source << " " << time_array->length() << " "
-                      << text_value_array->length() << " "
-                      << numeric_value_array->length() << std::endl;
-
             for (int64_t i = 0; i < text_value_array->length(); i++) {
-                Row row;
-
                 if (!patient_id_array->IsValid(i)) {
                     throw std::runtime_error(
                         "patient_id incorrectly has null value " + source);
                 }
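[Review note — not part of the patch] Viewing a primitive column as FixedSizeBinary lets the reader copy values as opaque bytes without a per-type code path. A standalone sketch of the same View trick on an int64 column (all names here are illustrative):

    #include <arrow/api.h>
    #include <iostream>

    arrow::Status run() {
        arrow::Int64Builder builder;
        ARROW_RETURN_NOT_OK(builder.AppendValues({10, 20, 30}));
        std::shared_ptr<arrow::Array> values;
        ARROW_RETURN_NOT_OK(builder.Finish(&values));

        // Reinterpret the 8-byte integers as fixed-size binary blobs.
        ARROW_ASSIGN_OR_RAISE(
            auto raw,
            values->View(std::make_shared<arrow::FixedSizeBinaryType>(8)));
        auto binary = std::static_pointer_cast<arrow::FixedSizeBinaryArray>(raw);

        std::cout << binary->GetView(1).size() << "\n";  // prints 8
        return arrow::Status::OK();
    }

    int main() { return run().ok() ? 0 : 1; }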
@@ -417,250 +465,142 @@ void sort_reader(
                         "code incorrectly has null value " + source);
                 }

-                row.patient_id = patient_id_array->Value(i);
-                row.time = time_array->Value(i);
+                std::vector<char> data;
+
+                int64_t patient_id = patient_id_array->Value(i);
+                int64_t time = time_array->Value(i);

                 std::bitset<std::numeric_limits<uint64_t>::digits> non_null;

-                add_literal_to_vector(row.data, row.patient_id);
-                add_literal_to_vector(row.data, row.time);
+                add_literal_to_vector(data, patient_id);
+                add_literal_to_vector(data, time);

-                add_string_to_vector(row.data, code_array->Value(i));
+                add_string_to_vector(data, code_array->Value(i));

                 if (numeric_value_array->IsValid(i)) {
                     non_null[0] = true;
-                    add_literal_to_vector(row.data,
+                    add_literal_to_vector(data,
                                           numeric_value_array->Value(i));
                 }

                 if (datetime_value_array->IsValid(i)) {
                     non_null[1] = true;
-                    add_literal_to_vector(row.data,
+                    add_literal_to_vector(data,
                                           datetime_value_array->Value(i));
                 }

                 if (text_value_array->IsValid(i)) {
                     non_null[2] = true;
-                    add_string_to_vector(row.data,
-                                         text_value_array->Value(i));
+                    add_string_to_vector(data, text_value_array->Value(i));
                 }

                 for (size_t j = 0; j < metadata_columns.size(); j++) {
-                    if (metadata_arrays[j] &&
-                        metadata_arrays[j]->IsValid(i)) {
-                        non_null[3 + j] = true;
-                        add_string_to_vector(row.data,
-                                             metadata_arrays[j]->Value(i));
+                    if (is_text_metadata[j]) {
+                        if (text_metadata_arrays[j] &&
+                            text_metadata_arrays[j]->IsValid(i)) {
+                            non_null[3 + j] = true;
+                            add_string_to_vector(
+                                data, text_metadata_arrays[j]->Value(i));
+                        }
+                    } else {
+                        if (primitive_metadata_arrays[j] &&
+                            primitive_metadata_arrays[j]->IsValid(i)) {
+                            non_null[3 + j] = true;
+                            add_string_to_vector(
+                                data,
+                                primitive_metadata_arrays[j]->GetView(i));
+                        }
                     }
                 }

-                add_literal_to_vector(row.data, non_null.to_ullong());
+                add_literal_to_vector(data, non_null.to_ullong());

                 size_t index =
-                    std::hash<int64_t>()(row.patient_id) % num_shards;
-                all_write_queues[index][reader_index].wait_enqueue(
-                    std::move(row));
+                    std::hash<int64_t>()(patient_id) % num_shards;
+                all_write_queues[index].enqueue(ptoks[index],
+                                                std::move(data));
+
+                slots_to_write--;
+                if (slots_to_write == 0) {
+                    slots_to_write =
+                        all_write_semaphore.waitMany(SEMAPHORE_BLOCK_SIZE);
+                }
             }
         }
     }

     for (size_t j = 0; j < num_shards; j++) {
-        all_write_queues[j][reader_index].wait_enqueue(absl::nullopt);
+        all_write_queues[j].enqueue(ptoks[j], absl::nullopt);
     }
-}
-
-template <typename T, typename F>
-void dequeue_many_loop(T& in_queues, F f) {
-    std::vector<size_t> good_indices;
-    good_indices.reserve(in_queues.size());
-    for (size_t i = 0; i < in_queues.size(); i++) {
-        good_indices.push_back(i);
-    }
-
-    typename T::value_type::value_type next_entry;
-    while (good_indices.size() > 0) {
-        for (size_t i = 1; i <= good_indices.size(); i++) {
-            size_t index = good_indices[i - 1];
-            while (true) {
-                bool found = in_queues[index].try_dequeue(next_entry);
-
-                if (!found) {
-                    break;
-                }
-
-                if (!next_entry) {
-                    std::swap(good_indices[i - 1], good_indices.back());
-                    good_indices.pop_back();
-                    i -= 1;
-                    break;
-                } else {
-                    f(*next_entry);
-                }
-            }
-        }
-    }
-}
+
+    if (slots_to_write > 0) {
+        all_write_semaphore.signal(slots_to_write);
+    }
+}

-void sort_writer(
-    size_t writer_index, size_t num_shards,
-    std::vector<QueueType>&
-        write_queues,
-    const std::filesystem::path& target_dir) {
-    std::filesystem::create_directory(target_dir);
-
-    std::vector<Row> rows;
-    std::vector<std::tuple<int64_t, int64_t, size_t>> row_indices;
-
-    size_t current_size = 0;
-
-    size_t current_file_index = 0;
-
-    std::vector<char> uncompressed_data;
-    std::vector<char> compressed_data;
-
-    auto flush_file = [&]() {
-        auto target_file = target_dir / std::to_string(current_file_index);
-
-        std::sort(std::begin(row_indices), std::end(row_indices));
-
-        std::ofstream writer(target_file,
-                             std::ofstream::binary | std::ofstream::out);
-
-        auto flush_compressed = [&]() {
-            if (compressed_data.size() <
-                snappy::MaxCompressedLength(uncompressed_data.size())) {
-                compressed_data.resize(
-                    snappy::MaxCompressedLength(uncompressed_data.size()) * 2);
-            }
-
-            size_t compressed_length;
-            snappy::RawCompress(uncompressed_data.data(),
-                                uncompressed_data.size(),
-                                compressed_data.data(), &compressed_length);
-
-            writer.write(reinterpret_cast<const char*>(&compressed_length),
-                         sizeof(compressed_length));
-            writer.write(compressed_data.data(), compressed_length);
-
-            uncompressed_data.clear();
-        };
-
-        for (const auto& row_index : row_indices) {
-            const auto& row_to_insert = rows[std::get<2>(row_index)];
-            add_string_to_vector(uncompressed_data,
-                                 std::string_view(row_to_insert.data.data(),
-                                                  row_to_insert.data.size()));
-
-            if (uncompressed_data.size() > SNAPPY_BUFFER_SIZE) {
-                flush_compressed();
-            }
-        }
-
-        if (uncompressed_data.size() > 0) {
-            flush_compressed();
-        }
-
-        rows.clear();
-        row_indices.clear();
-
-        current_size = 0;
-    };
-
-    dequeue_many_loop(write_queues, [&](Row& r) {
-        current_size += sizeof(size_t) + r.data.size();
-
-        rows.emplace_back(std::move(r));
-        row_indices.emplace_back(
-            std::make_tuple(r.patient_id, r.time, row_indices.size()));
-
-        if (current_size > SHARD_PIECE_SIZE) {
-            flush_file();
-        }
-    });
-
-    if (current_size > 0) {
-        flush_file();
-    }
-}
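[Review note — not part of the patch] As I read the new row encoding: fixed-width patient_id and time first, then the variable-length and optional fields, with a 64-bit presence bitmask appended last (bit 0 = numeric_value, bit 1 = datetime_value, bit 2 = text_value, bits 3+ = metadata columns). A simplified, standalone mock-up of that shape — the length-prefix detail is my assumption based on the decode loop later in this file:

    #include <bitset>
    #include <cstdint>
    #include <cstring>
    #include <string>
    #include <vector>

    // Append the raw bytes of a fixed-width value (mirrors add_literal_to_vector).
    template <typename T>
    void put(std::vector<char>& out, T v) {
        const char* p = reinterpret_cast<const char*>(&v);
        out.insert(out.end(), p, p + sizeof(T));
    }

    int main() {
        std::vector<char> row;
        put<int64_t>(row, 42);          // patient_id
        put<int64_t>(row, 1700000000);  // time

        std::string code = "ICD10/E11.9";
        put<size_t>(row, code.size());  // assumed length prefix
        row.insert(row.end(), code.begin(), code.end());

        // No optional fields present here; their payloads would be appended
        // before the mask, with the matching bit set.
        std::bitset<64> non_null;
        put<uint64_t>(row, non_null.to_ullong());

        int64_t patient_id;
        std::memcpy(&patient_id, row.data(), sizeof(patient_id));
        return patient_id == 42 ? 0 : 1;
    }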
-const int QUEUE_SIZE = 1000;
-
-std::vector<std::string> sort_and_shard(
-    const std::filesystem::path& source_directory,
-    const std::filesystem::path& target_directory, size_t num_shards) {
-    std::filesystem::create_directory(target_directory);
-
-    std::vector<std::string> paths;
-
-    for (const auto& entry : fs::directory_iterator(source_directory)) {
-        paths.push_back(entry.path());
-    }
-
-    auto set_metadata_fields =
-        get_metadata_fields_multithreaded(paths, num_shards);
-
-    std::vector<std::string> metadata_columns(std::begin(set_metadata_fields),
-                                              std::end(set_metadata_fields));
-
-    if (metadata_columns.size() + 3 >
-        std::numeric_limits<uint64_t>::digits) {
-        throw std::runtime_error(
-            "C++ MEDS-ETL currently only supports at most " +
-            std::to_string(std::numeric_limits<uint64_t>::digits) +
-            " metadata columns");
-    }
-
-    moodycamel::BlockingConcurrentQueue<absl::optional<std::string>> file_queue;
-
-    for (const auto& path : paths) {
-        file_queue.enqueue(path);
-    }
-
-    for (size_t i = 0; i < num_shards; i++) {
-        file_queue.enqueue({});
-    }
-
-    std::vector<
-        std::vector<QueueType>>
-        write_queues(num_shards);
-
-    for (size_t i = 0; i < num_shards; i++) {
-        for (size_t j = 0; j < num_shards; j++) {
-            write_queues[i].emplace_back(QUEUE_SIZE);
-        }
-    }
-
-    std::vector<std::thread> threads;
-
-    for (size_t i = 0; i < num_shards; i++) {
-        threads.emplace_back(
-            [i, &file_queue, &write_queues, num_shards, &metadata_columns]() {
-                sort_reader(i, num_shards, file_queue, write_queues,
-                            metadata_columns);
-            });
-
-        threads.emplace_back(
-            [i, &write_queues, num_shards, target_directory]() {
-                sort_writer(i, num_shards, write_queues[i],
-                            target_directory / std::to_string(i));
-            });
-    }
-
-    for (auto& thread : threads) {
-        thread.join();
-    }
-
-    return metadata_columns;
-}
+class ZstdRowWriter {
+   public:
+    ZstdRowWriter(const std::string& path, ZSTD_CCtx* ctx)
+        : fname(path),
+          fstream(path, std::ifstream::out | std::ifstream::binary),
+          context(ctx) {}
+
+    void add_next(std::string_view data) {
+        add_string_to_vector(uncompressed_buffer,
+                             std::string_view(data.data(), data.size()));
+
+        if (uncompressed_buffer.size() > COMPRESSION_BUFFER_SIZE) {
+            flush_compressed();
+        }
+    }
+
+    ~ZstdRowWriter() {
+        if (uncompressed_buffer.size() > 0) {
+            flush_compressed();
+        }
+    }
+
+    const std::string fname;
+
+   private:
+    void flush_compressed() {
+        size_t needed_size = ZSTD_compressBound(uncompressed_buffer.size());
+
+        if (compressed_buffer.size() < needed_size) {
+            compressed_buffer.resize(needed_size * 2);
+        }
+
+        size_t compressed_length = ZSTD_compressCCtx(
+            context, compressed_buffer.data(), compressed_buffer.size(),
+            uncompressed_buffer.data(), uncompressed_buffer.size(), 1);
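[Review note — not part of the patch] ZstdRowWriter takes a caller-owned ZSTD_CCtx and reuses it for every flushed block, which amortizes the context's internal allocations across blocks (the DCtx on the read side works the same way). A minimal sketch of that reuse, with illustrative buffer contents:

    #include <string>
    #include <vector>
    #include <zstd.h>

    int main() {
        ZSTD_CCtx* cctx = ZSTD_createCCtx();
        std::vector<std::string> blocks = {"block one", "block two"};
        std::vector<char> out;

        for (const auto& b : blocks) {
            out.resize(ZSTD_compressBound(b.size()));
            // Reusing the same context avoids re-allocating its tables.
            size_t n = ZSTD_compressCCtx(cctx, out.data(), out.size(),
                                         b.data(), b.size(), /*level=*/1);
            if (ZSTD_isError(n)) { ZSTD_freeCCtx(cctx); return 1; }
        }

        ZSTD_freeCCtx(cctx);
        return 0;
    }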
+        if (ZSTD_isError(compressed_length)) {
+            throw std::runtime_error("Could not compress using zstd?");
+        }
+
+        fstream.write(reinterpret_cast<const char*>(&compressed_length),
+                      sizeof(compressed_length));
+        fstream.write(compressed_buffer.data(), compressed_length);
+
+        uncompressed_buffer.clear();
+    }
+
+    std::ofstream fstream;
+
+    ZSTD_CCtx* context;
+
+    std::vector<char> compressed_buffer;
+    std::vector<char> uncompressed_buffer;
+};

-class SnappyRowReader {
+class ZstdRowReader {
    public:
-    SnappyRowReader(const std::string& path)
+    ZstdRowReader(const std::string& path, ZSTD_DCtx* ctx)
         : fname(path),
           fstream(path, std::ifstream::in | std::ifstream::binary),
+          context(ctx),
           current_offset(0),
           uncompressed_size(0) {}
@@ -714,43 +654,324 @@ class SnappyRowReader {

         fstream.read(compressed_buffer.data(), size);

-        bool is_valid = snappy::GetUncompressedLength(compressed_buffer.data(),
-                                                      size, &uncompressed_size);
-        if (!is_valid) {
+        uncompressed_size =
+            ZSTD_getFrameContentSize(compressed_buffer.data(), size);
+
+        if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR ||
+            uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
             throw std::runtime_error(
-                "Could not get size of compressed snappy data?");
+                "Could not get the size of the zstd compressed stream?");
         }

         if (uncompressed_buffer.size() < uncompressed_size) {
             uncompressed_buffer.resize(uncompressed_size * 2);
         }

-        is_valid = snappy::RawUncompress(compressed_buffer.data(), size,
-                                         uncompressed_buffer.data());
-        if (!is_valid) {
-            throw std::runtime_error("Could not decompress snappy data?");
+        size_t read_size = ZSTD_decompressDCtx(
+            context, uncompressed_buffer.data(), uncompressed_size,
+            compressed_buffer.data(), size);
+
+        if (ZSTD_isError(read_size) || read_size != uncompressed_size) {
+            throw std::runtime_error("Could not decompress zstd data?");
         }

         current_offset = 0;

         return true;
     }

-    std::string fname;
+    const std::string fname;

     std::ifstream fstream;

+    ZSTD_DCtx* context;
+
     std::vector<char> compressed_buffer;
     std::vector<char> uncompressed_buffer;

     size_t current_offset;
     size_t uncompressed_size;
 };
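[Review note — not part of the patch] shard_writer (added below) keeps its current output file in a std::optional and rotates with reset()/emplace() once SHARD_PIECE_SIZE is exceeded, so the ZstdRowWriter destructor does the final flush. A generic sketch of that rotate-on-size pattern, with illustrative file names and sizes:

    #include <fstream>
    #include <optional>
    #include <string>

    int main() {
        std::optional<std::ofstream> out;
        size_t bytes = 0, file_index = 0;
        const size_t piece_size = 1 << 20;  // 1 MiB pieces for illustration

        for (int i = 0; i < 100000; i++) {
            if (!out) {
                out.emplace("piece_" + std::to_string(file_index));
            }
            std::string line = "row " + std::to_string(i) + "\n";
            out->write(line.data(), line.size());
            bytes += line.size();

            if (bytes > piece_size) {
                out.reset();  // closes (and, for ZstdRowWriter, flushes)
                bytes = 0;
                file_index += 1;
            }
        }
    }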
+void shard_writer(
+    size_t writer_index, size_t num_shards,
+    moodycamel::BlockingConcurrentQueue<QueueItem>& write_queue,
+    moodycamel::LightweightSemaphore& write_semaphore,
+    const std::filesystem::path& target_dir,
+    moodycamel::BlockingConcurrentQueue<std::string>& sort_file_queue,
+    std::atomic<size_t>& remaining_live_writers) {
+    std::filesystem::create_directory(target_dir);
+
+    size_t current_size = 0;
+
+    size_t current_file_index = 0;
+
+    auto context_deleter = [](ZSTD_CCtx* context) { ZSTD_freeCCtx(context); };
+
+    std::unique_ptr<ZSTD_CCtx, decltype(context_deleter)> context{
+        ZSTD_createCCtx(), context_deleter};
+
+    std::optional<ZstdRowWriter> current_writer;
+
+    auto init_file = [&]() {
+        auto target_file = target_dir / std::to_string(current_file_index);
+        current_writer.emplace(target_file, context.get());
+    };
+
+    auto flush_file = [&]() {
+        std::string target_file = current_writer->fname;
+        current_writer.reset();
+
+        sort_file_queue.enqueue(target_file);
+
+        current_size = 0;
+        current_file_index += 1;
+    };
+
+    QueueItem item;
+    size_t readers_remaining = num_shards;
+
+    moodycamel::ConsumerToken ctok(write_queue);
+
+    size_t num_read = 0;
+
+    while (true) {
+        write_queue.wait_dequeue(ctok, item);
+
+        if (!item) {
+            readers_remaining--;
+            if (readers_remaining == 0) {
+                break;
+            } else {
+                continue;
+            }
+        }
+
+        num_read++;
+        if (num_read == SEMAPHORE_BLOCK_SIZE) {
+            write_semaphore.signal(num_read);
+            num_read = 0;
+        }
+
+        std::vector<char>& r = *item;
+
+        if (!current_writer) {
+            init_file();
+        }
+
+        current_writer->add_next(std::string_view(r.data(), r.size()));
+
+        current_size += sizeof(size_t) + r.size();
+
+        if (current_size > SHARD_PIECE_SIZE) {
+            flush_file();
+        }
+    }
+
+    write_semaphore.signal(num_read);
+
+    if (current_writer) {
+        flush_file();
+    }
+
+    remaining_live_writers.fetch_sub(1, std::memory_order_release);
+}
+
+void shard_sort(
+    moodycamel::BlockingConcurrentQueue<std::string>& sort_file_queue,
+    const std::atomic<size_t>& remaining_live_writers) {
+    absl::optional<std::string> next_entry;
+
+    std::vector<char> data;
+    std::vector<std::tuple<int64_t, int64_t, size_t, size_t>> row_indices;
+
+    auto compression_context_deleter = [](ZSTD_CCtx* context) {
+        ZSTD_freeCCtx(context);
+    };
+    auto decompression_context_deleter = [](ZSTD_DCtx* context) {
+        ZSTD_freeDCtx(context);
+    };
+
+    std::unique_ptr<ZSTD_CCtx, decltype(compression_context_deleter)>
+        compression_context{ZSTD_createCCtx(), compression_context_deleter};
+
+    std::unique_ptr<ZSTD_DCtx, decltype(decompression_context_deleter)>
+        decompression_context{ZSTD_createDCtx(), decompression_context_deleter};
+
+    std::string filename;
+    while (true) {
+        while (true) {
+            bool found = sort_file_queue.wait_dequeue_timed(filename, 1e6);
+            if (found) {
+                break;
+            }
+
+            // No items are available. This could be due to being fully done.
+
+            // Check if we are done
+            if (remaining_live_writers.load(std::memory_order_acquire) == 0) {
+                return;
+            }
+
+            // Need to wait more
+        }
+
+        data.clear();
+        row_indices.clear();
+
+        {
+            ZstdRowReader reader(filename, decompression_context.get());
+
+            while (true) {
+                auto next = reader.get_next();
+
+                if (!next) {
+                    break;
+                }
+
+                size_t start = data.size();
+                size_t length = std::get<2>(*next).size();
+
+                data.insert(std::end(data), std::begin(std::get<2>(*next)),
+                            std::end(std::get<2>(*next)));
+                row_indices.push_back(std::make_tuple(
+                    std::get<0>(*next), std::get<1>(*next), start, length));
+            }
+        }
+
+        std::sort(std::begin(row_indices), std::end(row_indices));
+
+        {
+            ZstdRowWriter writer(filename, compression_context.get());
+
+            for (const auto& row_index : row_indices) {
+                writer.add_next(
+                    std::string_view(data.data() + std::get<2>(row_index),
+                                     std::get<3>(row_index)));
+            }
+        }
+    }
+}
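[Review note — not part of the patch] shard_sort cannot simply block on the queue, because more files may still arrive while any writer is alive; the timed dequeue plus an atomic count of live producers is the shutdown protocol. A reduced sketch of that idiom (types and counts simplified):

    #include <atomic>
    #include <thread>
    #include "blockingconcurrentqueue.h"

    int main() {
        moodycamel::BlockingConcurrentQueue<int> work;
        std::atomic<size_t> live_producers(1);

        std::thread producer([&]() {
            for (int i = 0; i < 1000; i++) work.enqueue(i);
            live_producers.fetch_sub(1, std::memory_order_release);
        });

        std::thread consumer([&]() {
            int item;
            while (true) {
                // Wake periodically so we can notice producers exiting.
                if (work.wait_dequeue_timed(item, /*usecs=*/1000000)) {
                    continue;  // process the item here
                }
                // Queue looked empty: only stop once no producer remains.
                if (live_producers.load(std::memory_order_acquire) == 0) return;
            }
        });

        producer.join();
        consumer.join();
    }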
+constexpr int QUEUE_SIZE = 10000;
+
+std::vector<std::pair<std::string, std::shared_ptr<arrow::DataType>>>
+sort_and_shard(const std::filesystem::path& source_directory,
+               const std::filesystem::path& target_directory,
+               size_t num_shards) {
+    std::filesystem::create_directory(target_directory);
+
+    std::vector<std::string> paths;
+
+    for (const auto& entry : fs::directory_iterator(source_directory)) {
+        paths.push_back(entry.path());
+    }
+
+    auto set_metadata_fields =
+        get_metadata_fields_multithreaded(paths, num_shards);
+
+    std::vector<std::pair<std::string, std::shared_ptr<arrow::DataType>>>
+        metadata_columns(std::begin(set_metadata_fields),
+                         std::end(set_metadata_fields));
+    std::sort(std::begin(metadata_columns), std::end(metadata_columns));
+
+    metadata_columns.erase(
+        std::unique(std::begin(metadata_columns), std::end(metadata_columns),
+                    [](const auto& a, const auto& b) {
+                        return (a.first == b.first) &&
+                               a.second->Equals(b.second);
+                    }),
+        std::end(metadata_columns));
+
+    for (ssize_t i = 0; i < static_cast<ssize_t>(metadata_columns.size()) - 1;
+         i++) {
+        if (metadata_columns[i].first == metadata_columns[i + 1].first) {
+            throw std::runtime_error(
+                "Got conflicting types for column " +
+                metadata_columns[i].first +
+                ", types: " + metadata_columns[i].second->ToString() + " vs " +
+                metadata_columns[i + 1].second->ToString());
+        }
+    }
+
+    if (metadata_columns.size() + 3 >
+        std::numeric_limits<uint64_t>::digits) {
+        throw std::runtime_error(
+            "C++ MEDS-ETL currently only supports at most " +
+            std::to_string(std::numeric_limits<uint64_t>::digits) +
+            " metadata columns");
+    }
+
+    moodycamel::BlockingConcurrentQueue<absl::optional<std::string>> file_queue;
+
+    for (const auto& path : paths) {
+        file_queue.enqueue(path);
+    }
+
+    for (size_t i = 0; i < num_shards; i++) {
+        file_queue.enqueue({});
+    }
+
+    std::vector<moodycamel::BlockingConcurrentQueue<QueueItem>> write_queues(
+        num_shards);
+
+    std::vector<std::thread> threads;
+
+    moodycamel::LightweightSemaphore write_semaphore(QUEUE_SIZE * num_shards);
+
+    moodycamel::BlockingConcurrentQueue<std::string> sort_queue;
+    std::atomic<size_t> remaining_live_writers(num_shards);
+
+    for (size_t i = 0; i < num_shards; i++) {
+        threads.emplace_back([i, &file_queue, &write_queues, &write_semaphore,
+                              num_shards, &metadata_columns]() {
+            shard_reader(i, num_shards, file_queue, write_queues,
+                         write_semaphore, metadata_columns);
+        });
+
+        threads.emplace_back([i, &write_queues, &write_semaphore, num_shards,
+                              target_directory, &sort_queue,
+                              &remaining_live_writers]() {
+            shard_writer(i, num_shards, write_queues[i], write_semaphore,
+                         target_directory / std::to_string(i), sort_queue,
+                         remaining_live_writers);
+        });
+
+        threads.emplace_back([&sort_queue, &remaining_live_writers]() {
+            shard_sort(sort_queue, remaining_live_writers);
+        });
+    }
+
+    for (auto& thread : threads) {
+        thread.join();
+    }
+
+    absl::optional<std::string> next_entry;
+    if (sort_queue.try_dequeue(next_entry)) {
+        // This should not be possible
+        throw std::runtime_error(
+            "Had excess unsorted items. This should not be possible");
+    }
+
+    return metadata_columns;
+}
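[Review note — not part of the patch] One observation on sort_and_shard: shard assignment uses std::hash<int64_t>, whose output is implementation-defined, so shard contents can differ across standard libraries and platforms. If reproducible shards matter, a fixed mixer is a drop-in replacement — this is my suggestion, not something the patch does; the splitmix64 constants below are the standard published ones:

    #include <cstdint>
    #include <cstdio>

    // splitmix64 finalizer: a deterministic 64-bit mixer.
    uint64_t mix64(uint64_t x) {
        x += 0x9e3779b97f4a7c15ULL;
        x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL;
        x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL;
        return x ^ (x >> 31);
    }

    int main() {
        int64_t patient_id = 12345;
        size_t num_shards = 10;
        size_t shard = mix64(static_cast<uint64_t>(patient_id)) % num_shards;
        std::printf("%zu\n", shard);  // identical on every platform
    }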
This should not be possible"); + } + + return metadata_columns; +} + +void join_and_write_single( + const std::filesystem::path& source_directory, + const std::filesystem::path& target_path, + const std::vector>>& + metadata_columns) { arrow::FieldVector metadata_fields; - for (const auto& metadata_column : metadata_columns) { - metadata_fields.push_back(arrow::field( - metadata_column, std::make_shared())); + std::bitset::digits> + is_text_metadata; + for (size_t i = 0; i < metadata_columns.size(); i++) { + const auto& metadata_column = metadata_columns[i]; + if (metadata_column.second->Equals(arrow::LargeStringType())) { + is_text_metadata[i] = true; + metadata_fields.push_back(arrow::field( + metadata_column.first, std::make_shared())); + } else { + is_text_metadata[i] = false; + metadata_fields.push_back( + arrow::field(metadata_column.first, metadata_column.second)); + } } auto metadata_type = std::make_shared(metadata_fields); @@ -778,12 +999,13 @@ void join_and_write_single(const std::filesystem::path& source_directory, std::make_shared(measurement_type)), }; auto event_type = std::make_shared(event_type_fields); + auto events_type = std::make_shared(event_type); auto schema_fields = { arrow::field("patient_id", std::make_shared()), arrow::field("static_measurements", std::make_shared()), - arrow::field("events", std::make_shared(event_type)), + arrow::field("events", events_type), }; auto schema = std::make_shared(schema_fields); @@ -813,13 +1035,22 @@ void join_and_write_single(const std::filesystem::path& source_directory, writer, parquet::arrow::FileWriter::Open(*schema, pool, outfile, props, arrow_props)); - std::vector source_files; + std::vector source_files; + + auto context_deleter = [](ZSTD_DCtx* context) { ZSTD_freeDCtx(context); }; + + std::unique_ptr context{ + ZSTD_createDCtx(), context_deleter}; for (const auto& entry : fs::directory_iterator(source_directory)) { - source_files.emplace_back(entry.path()); + source_files.emplace_back(entry.path(), context.get()); } - std::priority_queue> + typedef std::tuple + PriorityQueueItem; + + std::priority_queue, + std::greater> queue; for (size_t i = 0; i < source_files.size(); i++) { @@ -844,16 +1075,28 @@ void join_and_write_single(const std::filesystem::path& source_directory, auto datetime_value_builder = std::make_shared(timestamp_type, pool); - std::vector> metadata_builders; - std::vector> metadata_builders_generic; + std::vector> text_metadata_builders( + metadata_columns.size()); + std::vector> + primitive_metadata_builders(metadata_columns.size()); + std::vector> metadata_builders( + metadata_columns.size()); for (size_t i = 0; i < metadata_columns.size(); i++) { - auto builder = std::make_shared(pool); - metadata_builders.push_back(builder); - metadata_builders_generic.push_back(builder); + if (is_text_metadata[i]) { + auto builder = std::make_shared(pool); + text_metadata_builders[i] = builder; + metadata_builders[i] = builder; + } else { + auto builder = std::make_shared( + std::make_shared( + metadata_columns[i].second->byte_width())); + primitive_metadata_builders[i] = builder; + metadata_builders[i] = builder; + } } auto metadata_builder = std::make_shared( - metadata_type, pool, metadata_builders_generic); + metadata_type, pool, metadata_builders); std::vector> measurement_builder_fields{code_builder, text_value_builder, @@ -880,7 +1123,10 @@ void join_and_write_single(const std::filesystem::path& source_directory, PARQUET_THROW_NOT_OK(patient_id_builder->Finish(columns.data() + 0)); PARQUET_THROW_NOT_OK( 
@@ -880,7 +1123,10 @@ void join_and_write_single(const std::filesystem::path& source_directory,
         PARQUET_THROW_NOT_OK(patient_id_builder->Finish(columns.data() + 0));
         PARQUET_THROW_NOT_OK(
             static_measurements_builder->Finish(columns.data() + 1));
-        PARQUET_THROW_NOT_OK(events_builder->Finish(columns.data() + 2));
+
+        std::shared_ptr<arrow::Array> events_array;
+        PARQUET_THROW_NOT_OK(events_builder->Finish(&events_array));
+        PARQUET_ASSIGN_OR_THROW(columns[2], events_array->View(events_type));

         std::shared_ptr<arrow::Table> table = arrow::Table::Make(schema, columns);

@@ -976,11 +1222,24 @@ void join_and_write_single(const std::filesystem::path& source_directory,
                     size_t size = *reinterpret_cast<const size_t*>(
                         patient_record.substr(offset).data());
                     offset += sizeof(size);
-                    PARQUET_THROW_NOT_OK(metadata_builders[j]->Append(
-                        patient_record.substr(offset, size)));
+                    auto entry = patient_record.substr(offset, size);
+
+                    if (is_text_metadata[j]) {
+                        PARQUET_THROW_NOT_OK(
+                            text_metadata_builders[j]->Append(entry));
+                    } else {
+                        PARQUET_THROW_NOT_OK(
+                            primitive_metadata_builders[j]->Append(entry));
+                    }
                     offset += size;
                 } else {
-                    PARQUET_THROW_NOT_OK(metadata_builders[j]->AppendNull());
+                    if (is_text_metadata[j]) {
+                        PARQUET_THROW_NOT_OK(
+                            text_metadata_builders[j]->AppendNull());
+                    } else {
+                        PARQUET_THROW_NOT_OK(
+                            primitive_metadata_builders[j]->AppendNull());
+                    }
                 }
             }

@@ -1001,9 +1260,11 @@ void join_and_write_single(const std::filesystem::path& source_directory,
     PARQUET_THROW_NOT_OK(writer->Close());
 }

-void join_and_write(const std::filesystem::path& source_directory,
-                    const std::filesystem::path& target_directory,
-                    const std::vector<std::string>& metadata_columns) {
+void join_and_write(
+    const std::filesystem::path& source_directory,
+    const std::filesystem::path& target_directory,
+    const std::vector<std::pair<std::string, std::shared_ptr<arrow::DataType>>>&
+        metadata_columns) {
     std::filesystem::create_directory(target_directory);

     std::vector<std::string> shards;

@@ -1048,4 +1309,4 @@ void perform_etl(const std::string& source_directory,
     join_and_write(shard_path, data_path, metadata_columns);

     fs::remove_all(shard_path);
-}
\ No newline at end of file
+}
diff --git a/setup.py b/setup.py
index 66a14ac..9ad2427 100644
--- a/setup.py
+++ b/setup.py
@@ -60,9 +60,6 @@ def build_extensions(self) -> None:
         if source_env.get("DISTDIR"):
             extra_args.extend(["--distdir", source_env["DISTDIR"]])

-        if has_nvcc():
-            extra_args.extend(["--//:cuda=enabled"])
-
         if source_env.get("MACOSX_DEPLOYMENT_TARGET"):
             extra_args.extend(["--macos_minimum_os", source_env["MACOSX_DEPLOYMENT_TARGET"]])