diff --git a/yggdrasil_decision_forests/dataset/BUILD b/yggdrasil_decision_forests/dataset/BUILD index b27d7500..035785c9 100644 --- a/yggdrasil_decision_forests/dataset/BUILD +++ b/yggdrasil_decision_forests/dataset/BUILD @@ -383,6 +383,7 @@ cc_library_ydf( "//yggdrasil_decision_forests/utils:bytestream", "//yggdrasil_decision_forests/utils:filesystem", "//yggdrasil_decision_forests/utils:status_macros", + "//yggdrasil_decision_forests/utils:zlib", "@com_github_tencent_rapidjson//:rapidjson", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", diff --git a/yggdrasil_decision_forests/dataset/avro.cc b/yggdrasil_decision_forests/dataset/avro.cc index ee7d1cc7..393bb419 100644 --- a/yggdrasil_decision_forests/dataset/avro.cc +++ b/yggdrasil_decision_forests/dataset/avro.cc @@ -32,19 +32,18 @@ #include "absl/types/optional.h" #include "include/rapidjson/document.h" #include "include/rapidjson/rapidjson.h" -#include "yggdrasil_decision_forests/dataset/data_spec.h" -#include "yggdrasil_decision_forests/dataset/data_spec.pb.h" -#include "yggdrasil_decision_forests/dataset/data_spec_inference.h" #include "yggdrasil_decision_forests/utils/bytestream.h" #include "yggdrasil_decision_forests/utils/filesystem.h" #include "yggdrasil_decision_forests/utils/status_macros.h" +#include "yggdrasil_decision_forests/utils/zlib.h" -#define MAYBE_SKIP_OPTIONAL(FIELD) \ - if (field.optional) { \ - ASSIGN_OR_RETURN(const auto _has_value, current_block_reader->ReadByte()); \ - if (!_has_value) { \ - return absl::nullopt; \ - } \ +#define MAYBE_SKIP_OPTIONAL(FIELD) \ + if (field.optional) { \ + ASSIGN_OR_RETURN(const auto _has_value, \ + current_block_reader_->ReadByte()); \ + if (!_has_value) { \ + return absl::nullopt; \ + } \ } namespace yggdrasil_decision_forests::dataset::avro { @@ -309,17 +308,24 @@ absl::StatusOr AvroReader::ReadNextBlock() { ASSIGN_OR_RETURN(const auto block_size, internal::ReadInteger(stream_.get())); - current_block.resize(block_size); + current_block_.resize(block_size); ASSIGN_OR_RETURN(bool has_read, - stream_->ReadExactly(¤t_block[0], block_size)); + stream_->ReadExactly(¤t_block_[0], block_size)); if (!has_read) { return absl::InvalidArgumentError("Unexpected end of stream"); } - current_block_reader = utils::StringViewInputByteStream(current_block); - if (codec_ != AvroCodec::kNull) { - // TODO: Implement deflate. - return absl::UnimplementedError("Compression not implemented"); + switch (codec_) { + case AvroCodec::kNull: + current_block_reader_ = utils::StringViewInputByteStream(current_block_); + break; + case AvroCodec::kDeflate: + zlib_working_buffer_.resize(1024 * 1024); + RETURN_IF_ERROR(utils::Inflate( + current_block_, ¤t_block_decompressed_, &zlib_working_buffer_)); + current_block_reader_ = + utils::StringViewInputByteStream(current_block_decompressed_); + break; } new_sync_marker_.resize(16); @@ -334,8 +340,8 @@ absl::StatusOr AvroReader::ReadNextBlock() { } absl::StatusOr AvroReader::ReadNextRecord() { - if (!current_block_reader.has_value() || - current_block_reader.value().left() == 0) { + if (!current_block_reader_.has_value() || + current_block_reader_.value().left() == 0) { // Read a new block of data. DCHECK_EQ(next_object_in_current_block_, num_objects_in_current_block_); ASSIGN_OR_RETURN(const bool has_next_block, ReadNextBlock()); @@ -350,37 +356,37 @@ absl::StatusOr AvroReader::ReadNextRecord() { absl::StatusOr> AvroReader::ReadNextFieldBoolean( const AvroField& field) { MAYBE_SKIP_OPTIONAL(field); - ASSIGN_OR_RETURN(const auto value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto value, current_block_reader_->ReadByte()); return value; } absl::StatusOr> AvroReader::ReadNextFieldInteger( const AvroField& field) { MAYBE_SKIP_OPTIONAL(field); - return internal::ReadInteger(¤t_block_reader.value()); + return internal::ReadInteger(¤t_block_reader_.value()); } absl::StatusOr> AvroReader::ReadNextFieldFloat( const AvroField& field) { MAYBE_SKIP_OPTIONAL(field); - return internal::ReadFloat(¤t_block_reader.value()); + return internal::ReadFloat(¤t_block_reader_.value()); } absl::StatusOr> AvroReader::ReadNextFieldDouble( const AvroField& field) { MAYBE_SKIP_OPTIONAL(field); - return internal::ReadDouble(¤t_block_reader.value()); + return internal::ReadDouble(¤t_block_reader_.value()); } absl::StatusOr AvroReader::ReadNextFieldString(const AvroField& field, std::string* value) { if (field.optional) { - ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte()); if (!has_value) { return false; } } - RETURN_IF_ERROR(internal::ReadString(¤t_block_reader.value(), value)); + RETURN_IF_ERROR(internal::ReadString(¤t_block_reader_.value(), value)); return true; } @@ -388,27 +394,27 @@ absl::StatusOr AvroReader::ReadNextFieldArrayFloat( const AvroField& field, std::vector* values) { values->clear(); if (field.optional) { - ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte()); if (!has_value) { return false; } } while (true) { ASSIGN_OR_RETURN(auto num_values, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); values->reserve(values->size() + num_values); if (num_values == 0) { break; } if (num_values < 0) { ASSIGN_OR_RETURN(auto block_size, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); (void)block_size; num_values = -num_values; } for (size_t value_idx = 0; value_idx < num_values; value_idx++) { ASSIGN_OR_RETURN(auto value, - internal::ReadFloat(¤t_block_reader.value())); + internal::ReadFloat(¤t_block_reader_.value())); values->push_back(value); } } @@ -420,27 +426,27 @@ absl::StatusOr AvroReader::ReadNextFieldArrayDouble( const AvroField& field, std::vector* values) { values->clear(); if (field.optional) { - ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte()); if (!has_value) { return false; } } while (true) { ASSIGN_OR_RETURN(auto num_values, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); values->reserve(values->size() + num_values); if (num_values == 0) { break; } if (num_values < 0) { ASSIGN_OR_RETURN(auto block_size, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); (void)block_size; num_values = -num_values; } for (size_t value_idx = 0; value_idx < num_values; value_idx++) { ASSIGN_OR_RETURN(auto value, - internal::ReadDouble(¤t_block_reader.value())); + internal::ReadDouble(¤t_block_reader_.value())); values->push_back(value); } } @@ -452,27 +458,27 @@ absl::StatusOr AvroReader::ReadNextFieldArrayDoubleIntoFloat( const AvroField& field, std::vector* values) { values->clear(); if (field.optional) { - ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte()); if (!has_value) { return false; } } while (true) { ASSIGN_OR_RETURN(auto num_values, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); if (num_values == 0) { break; } values->reserve(values->size() + num_values); if (num_values < 0) { ASSIGN_OR_RETURN(auto block_size, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); (void)block_size; num_values = -num_values; } for (size_t value_idx = 0; value_idx < num_values; value_idx++) { ASSIGN_OR_RETURN(auto value, - internal::ReadDouble(¤t_block_reader.value())); + internal::ReadDouble(¤t_block_reader_.value())); values->push_back(value); } } @@ -483,7 +489,7 @@ absl::StatusOr AvroReader::ReadNextFieldArrayDoubleIntoFloat( absl::StatusOr AvroReader::ReadNextFieldArrayString( const AvroField& field, std::vector* values) { if (field.optional) { - ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte()); + ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte()); if (!has_value) { return false; } @@ -491,21 +497,21 @@ absl::StatusOr AvroReader::ReadNextFieldArrayString( while (true) { ASSIGN_OR_RETURN(auto num_values, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); if (num_values == 0) { break; } values->reserve(values->size() + num_values); if (num_values < 0) { ASSIGN_OR_RETURN(auto block_size, - internal::ReadInteger(¤t_block_reader.value())); + internal::ReadInteger(¤t_block_reader_.value())); (void)block_size; num_values = -num_values; } for (size_t value_idx = 0; value_idx < num_values; value_idx++) { std::string sub_value; RETURN_IF_ERROR( - internal::ReadString(¤t_block_reader.value(), &sub_value)); + internal::ReadString(¤t_block_reader_.value(), &sub_value)); values->push_back(std::move(sub_value)); } } diff --git a/yggdrasil_decision_forests/dataset/avro.h b/yggdrasil_decision_forests/dataset/avro.h index e2ad88ad..fb40cab8 100644 --- a/yggdrasil_decision_forests/dataset/avro.h +++ b/yggdrasil_decision_forests/dataset/avro.h @@ -132,8 +132,10 @@ class AvroReader { AvroCodec codec_ = AvroCodec::kNull; // Raw and uncompressed data of the current block. - std::string current_block; - absl::optional current_block_reader; + std::string current_block_; + std::string current_block_decompressed_; + std::string zlib_working_buffer_; + absl::optional current_block_reader_; size_t num_objects_in_current_block_ = 0; size_t next_object_in_current_block_ = 0; diff --git a/yggdrasil_decision_forests/dataset/avro_example_test.cc b/yggdrasil_decision_forests/dataset/avro_example_test.cc index 06354abc..22730d96 100644 --- a/yggdrasil_decision_forests/dataset/avro_example_test.cc +++ b/yggdrasil_decision_forests/dataset/avro_example_test.cc @@ -347,7 +347,16 @@ TEST(AvroExample, CreateDataspec) { EXPECT_THAT(dataspec, EqualsProto(expected)); } -TEST(AvroExample, ReadExample) { +struct ReadExampleCase { + std::string filename; +}; + +SIMPLE_PARAMETERIZED_TEST(ReadExample, ReadExampleCase, + { + {"toy_codex-null.avro"}, + {"toy_codex-deflate.avro"}, + }) { + const auto& test_case = GetParam(); dataset::proto::DataSpecificationGuide guide; { auto* col = guide.add_column_guides(); @@ -367,11 +376,10 @@ TEST(AvroExample, ReadExample) { ASSERT_OK_AND_ASSIGN( const auto dataspec, - CreateDataspec(file::JoinPath(DatasetDir(), "toy_codex-null.avro"), - guide)); + CreateDataspec(file::JoinPath(DatasetDir(), test_case.filename), guide)); AvroExampleReader reader(dataspec, {}); - ASSERT_OK(reader.Open(file::JoinPath(DatasetDir(), "toy_codex-null.avro"))); + ASSERT_OK(reader.Open(file::JoinPath(DatasetDir(), test_case.filename))); proto::Example example; ASSERT_OK_AND_ASSIGN(bool has_next, reader.Next(&example)); ASSERT_TRUE(has_next); diff --git a/yggdrasil_decision_forests/utils/zlib.cc b/yggdrasil_decision_forests/utils/zlib.cc index ecb3a6bb..78f3461d 100644 --- a/yggdrasil_decision_forests/utils/zlib.cc +++ b/yggdrasil_decision_forests/utils/zlib.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "absl/log/check.h" @@ -44,19 +45,11 @@ GZipInputByteStream::Create(std::unique_ptr&& stream, size_t buffer_size) { auto gz_stream = std::make_unique(std::move(stream), buffer_size); - - gz_stream->deflate_stream_.zalloc = Z_NULL; - gz_stream->deflate_stream_.zfree = Z_NULL; - gz_stream->deflate_stream_.opaque = Z_NULL; - gz_stream->deflate_stream_.avail_in = 0; - gz_stream->deflate_stream_.next_in = Z_NULL; + std::memset(&gz_stream->deflate_stream_, 0, + sizeof(gz_stream->deflate_stream_)); if (inflateInit2(&gz_stream->deflate_stream_, 16 + MAX_WBITS) != Z_OK) { return absl::InternalError("Cannot initialize gzip stream"); } - // gz_stream->deflate_stream_.next_in = gz_stream->input_buffer_.data(); - // gz_stream->deflate_stream_.avail_in = 0; - // gz_stream->deflate_stream_.next_out = gz_stream->output_buffer_.data(); - // gz_stream->deflate_stream_.avail_out = 0; gz_stream->deflate_stream_is_allocated_ = true; return gz_stream; } @@ -163,12 +156,8 @@ GZipOutputByteStream::Create(std::unique_ptr&& stream, } auto gz_stream = std::make_unique(std::move(stream), buffer_size); - - gz_stream->deflate_stream_.zalloc = Z_NULL; - gz_stream->deflate_stream_.zfree = Z_NULL; - gz_stream->deflate_stream_.opaque = Z_NULL; - gz_stream->deflate_stream_.avail_in = 0; - gz_stream->deflate_stream_.next_in = Z_NULL; + std::memset(&gz_stream->deflate_stream_, 0, + sizeof(gz_stream->deflate_stream_)); if (deflateInit2(&gz_stream->deflate_stream_, compression_level, Z_DEFLATED, MAX_WBITS + 16, /*memLevel=*/8, // 8 is the recommended default @@ -255,6 +244,45 @@ absl::Status GZipOutputByteStream::CloseInflateStream() { return absl::OkStatus(); } +absl::Status Inflate(absl::string_view input, std::string* output, + std::string* working_buffer) { + if (working_buffer->size() < 1024) { + return absl::InvalidArgumentError( + "worker buffer should be at least 1024 bytes"); + } + z_stream stream; + std::memset(&stream, 0, sizeof(stream)); + // Note: A negative window size indicate to use the raw deflate algorithm (!= + // zlib or gzip). + if (inflateInit2(&stream, -15) != Z_OK) { + return absl::InternalError("Cannot initialize gzip stream"); + } + stream.next_in = reinterpret_cast(input.data()); + stream.avail_in = input.size(); + + while (true) { + stream.next_out = reinterpret_cast(&(*working_buffer)[0]); + stream.avail_out = working_buffer->size(); + const auto zlib_error = inflate(&stream, Z_NO_FLUSH); + if (zlib_error != Z_OK && zlib_error != Z_STREAM_END) { + inflateEnd(&stream); + return absl::InternalError(absl::StrCat("Internal error", zlib_error)); + } + if (stream.avail_out == 0) { + break; + } + const size_t produced_bytes = working_buffer->size() - stream.avail_out; + absl::StrAppend(output, + absl::string_view{working_buffer->data(), produced_bytes}); + if (zlib_error == Z_STREAM_END) { + break; + } + } + inflateEnd(&stream); + + return absl::OkStatus(); +} + } // namespace yggdrasil_decision_forests::utils #endif // THIRD_PARTY_YGGDRASIL_DECISION_FORESTS_UTILS_GZIP_H_ diff --git a/yggdrasil_decision_forests/utils/zlib.h b/yggdrasil_decision_forests/utils/zlib.h index bc26ee2f..e3a98b01 100644 --- a/yggdrasil_decision_forests/utils/zlib.h +++ b/yggdrasil_decision_forests/utils/zlib.h @@ -18,10 +18,12 @@ #include #include +#include #include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "yggdrasil_decision_forests/utils/bytestream.h" #define ZLIB_CONST @@ -94,6 +96,9 @@ class GZipOutputByteStream : public utils::OutputByteStream { bool deflate_stream_is_allocated_ = false; }; +absl::Status Inflate(absl::string_view input, std::string* output, + std::string* working_buffer); + } // namespace yggdrasil_decision_forests::utils #endif // THIRD_PARTY_YGGDRASIL_DECISION_FORESTS_UTILS_ZLIB_H_ diff --git a/yggdrasil_decision_forests/utils/zlib_test.cc b/yggdrasil_decision_forests/utils/zlib_test.cc index 512e07d5..359d156a 100644 --- a/yggdrasil_decision_forests/utils/zlib_test.cc +++ b/yggdrasil_decision_forests/utils/zlib_test.cc @@ -25,6 +25,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/log/log.h" +#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #include "yggdrasil_decision_forests/utils/filesystem.h" #include "yggdrasil_decision_forests/utils/logging.h" @@ -138,5 +139,14 @@ TEST_P(GZipTestCaseTest, WriteAndRead) { } } +TEST(RawDeflate, Base) { + const auto input = + absl::HexStringToBytes("05804109000008c4aa184ec1c7e0c08ff5c70ea43e470b"); + std::string output; + std::string working_buffer(1024, 0); + ASSERT_OK(Inflate(input, &output, &working_buffer)); + EXPECT_EQ(output, "hello world"); +} + } // namespace } // namespace yggdrasil_decision_forests::utils