Skip to content

Commit

Permalink
Read Avro files without dependencies (part 6)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 683975762
  • Loading branch information
achoum authored and copybara-github committed Oct 9, 2024
1 parent a7be9c1 commit 3843a7a
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 61 deletions.
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:bytestream",
"//yggdrasil_decision_forests/utils:filesystem",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:zlib",
"@com_github_tencent_rapidjson//:rapidjson",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
Expand Down
84 changes: 45 additions & 39 deletions yggdrasil_decision_forests/dataset/avro.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,18 @@
#include "absl/types/optional.h"
#include "include/rapidjson/document.h"
#include "include/rapidjson/rapidjson.h"
#include "yggdrasil_decision_forests/dataset/data_spec.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/data_spec_inference.h"
#include "yggdrasil_decision_forests/utils/bytestream.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
#include "yggdrasil_decision_forests/utils/zlib.h"

#define MAYBE_SKIP_OPTIONAL(FIELD) \
if (field.optional) { \
ASSIGN_OR_RETURN(const auto _has_value, current_block_reader->ReadByte()); \
if (!_has_value) { \
return absl::nullopt; \
} \
#define MAYBE_SKIP_OPTIONAL(FIELD) \
if (field.optional) { \
ASSIGN_OR_RETURN(const auto _has_value, \
current_block_reader_->ReadByte()); \
if (!_has_value) { \
return absl::nullopt; \
} \
}

namespace yggdrasil_decision_forests::dataset::avro {
Expand Down Expand Up @@ -309,17 +308,24 @@ absl::StatusOr<bool> AvroReader::ReadNextBlock() {

ASSIGN_OR_RETURN(const auto block_size, internal::ReadInteger(stream_.get()));

current_block.resize(block_size);
current_block_.resize(block_size);
ASSIGN_OR_RETURN(bool has_read,
stream_->ReadExactly(&current_block[0], block_size));
stream_->ReadExactly(&current_block_[0], block_size));
if (!has_read) {
return absl::InvalidArgumentError("Unexpected end of stream");
}
current_block_reader = utils::StringViewInputByteStream(current_block);

if (codec_ != AvroCodec::kNull) {
// TODO: Implement deflate.
return absl::UnimplementedError("Compression not implemented");
switch (codec_) {
case AvroCodec::kNull:
current_block_reader_ = utils::StringViewInputByteStream(current_block_);
break;
case AvroCodec::kDeflate:
zlib_working_buffer_.resize(1024 * 1024);
RETURN_IF_ERROR(utils::Inflate(
current_block_, &current_block_decompressed_, &zlib_working_buffer_));
current_block_reader_ =
utils::StringViewInputByteStream(current_block_decompressed_);
break;
}

new_sync_marker_.resize(16);
Expand All @@ -334,8 +340,8 @@ absl::StatusOr<bool> AvroReader::ReadNextBlock() {
}

absl::StatusOr<bool> AvroReader::ReadNextRecord() {
if (!current_block_reader.has_value() ||
current_block_reader.value().left() == 0) {
if (!current_block_reader_.has_value() ||
current_block_reader_.value().left() == 0) {
// Read a new block of data.
DCHECK_EQ(next_object_in_current_block_, num_objects_in_current_block_);
ASSIGN_OR_RETURN(const bool has_next_block, ReadNextBlock());
Expand All @@ -350,65 +356,65 @@ absl::StatusOr<bool> AvroReader::ReadNextRecord() {
absl::StatusOr<absl::optional<bool>> AvroReader::ReadNextFieldBoolean(
const AvroField& field) {
MAYBE_SKIP_OPTIONAL(field);
ASSIGN_OR_RETURN(const auto value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto value, current_block_reader_->ReadByte());
return value;
}

absl::StatusOr<absl::optional<int64_t>> AvroReader::ReadNextFieldInteger(
const AvroField& field) {
MAYBE_SKIP_OPTIONAL(field);
return internal::ReadInteger(&current_block_reader.value());
return internal::ReadInteger(&current_block_reader_.value());
}

absl::StatusOr<absl::optional<float>> AvroReader::ReadNextFieldFloat(
const AvroField& field) {
MAYBE_SKIP_OPTIONAL(field);
return internal::ReadFloat(&current_block_reader.value());
return internal::ReadFloat(&current_block_reader_.value());
}

absl::StatusOr<absl::optional<double>> AvroReader::ReadNextFieldDouble(
const AvroField& field) {
MAYBE_SKIP_OPTIONAL(field);
return internal::ReadDouble(&current_block_reader.value());
return internal::ReadDouble(&current_block_reader_.value());
}

absl::StatusOr<bool> AvroReader::ReadNextFieldString(const AvroField& field,
std::string* value) {
if (field.optional) {
ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte());
if (!has_value) {
return false;
}
}
RETURN_IF_ERROR(internal::ReadString(&current_block_reader.value(), value));
RETURN_IF_ERROR(internal::ReadString(&current_block_reader_.value(), value));
return true;
}

absl::StatusOr<bool> AvroReader::ReadNextFieldArrayFloat(
const AvroField& field, std::vector<float>* values) {
values->clear();
if (field.optional) {
ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte());
if (!has_value) {
return false;
}
}
while (true) {
ASSIGN_OR_RETURN(auto num_values,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
values->reserve(values->size() + num_values);
if (num_values == 0) {
break;
}
if (num_values < 0) {
ASSIGN_OR_RETURN(auto block_size,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
(void)block_size;
num_values = -num_values;
}
for (size_t value_idx = 0; value_idx < num_values; value_idx++) {
ASSIGN_OR_RETURN(auto value,
internal::ReadFloat(&current_block_reader.value()));
internal::ReadFloat(&current_block_reader_.value()));
values->push_back(value);
}
}
Expand All @@ -420,27 +426,27 @@ absl::StatusOr<bool> AvroReader::ReadNextFieldArrayDouble(
const AvroField& field, std::vector<double>* values) {
values->clear();
if (field.optional) {
ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte());
if (!has_value) {
return false;
}
}
while (true) {
ASSIGN_OR_RETURN(auto num_values,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
values->reserve(values->size() + num_values);
if (num_values == 0) {
break;
}
if (num_values < 0) {
ASSIGN_OR_RETURN(auto block_size,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
(void)block_size;
num_values = -num_values;
}
for (size_t value_idx = 0; value_idx < num_values; value_idx++) {
ASSIGN_OR_RETURN(auto value,
internal::ReadDouble(&current_block_reader.value()));
internal::ReadDouble(&current_block_reader_.value()));
values->push_back(value);
}
}
Expand All @@ -452,27 +458,27 @@ absl::StatusOr<bool> AvroReader::ReadNextFieldArrayDoubleIntoFloat(
const AvroField& field, std::vector<float>* values) {
values->clear();
if (field.optional) {
ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte());
if (!has_value) {
return false;
}
}
while (true) {
ASSIGN_OR_RETURN(auto num_values,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
if (num_values == 0) {
break;
}
values->reserve(values->size() + num_values);
if (num_values < 0) {
ASSIGN_OR_RETURN(auto block_size,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
(void)block_size;
num_values = -num_values;
}
for (size_t value_idx = 0; value_idx < num_values; value_idx++) {
ASSIGN_OR_RETURN(auto value,
internal::ReadDouble(&current_block_reader.value()));
internal::ReadDouble(&current_block_reader_.value()));
values->push_back(value);
}
}
Expand All @@ -483,29 +489,29 @@ absl::StatusOr<bool> AvroReader::ReadNextFieldArrayDoubleIntoFloat(
absl::StatusOr<bool> AvroReader::ReadNextFieldArrayString(
const AvroField& field, std::vector<std::string>* values) {
if (field.optional) {
ASSIGN_OR_RETURN(const auto has_value, current_block_reader->ReadByte());
ASSIGN_OR_RETURN(const auto has_value, current_block_reader_->ReadByte());
if (!has_value) {
return false;
}
}

while (true) {
ASSIGN_OR_RETURN(auto num_values,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
if (num_values == 0) {
break;
}
values->reserve(values->size() + num_values);
if (num_values < 0) {
ASSIGN_OR_RETURN(auto block_size,
internal::ReadInteger(&current_block_reader.value()));
internal::ReadInteger(&current_block_reader_.value()));
(void)block_size;
num_values = -num_values;
}
for (size_t value_idx = 0; value_idx < num_values; value_idx++) {
std::string sub_value;
RETURN_IF_ERROR(
internal::ReadString(&current_block_reader.value(), &sub_value));
internal::ReadString(&current_block_reader_.value(), &sub_value));
values->push_back(std::move(sub_value));
}
}
Expand Down
6 changes: 4 additions & 2 deletions yggdrasil_decision_forests/dataset/avro.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ class AvroReader {
AvroCodec codec_ = AvroCodec::kNull;

// Raw and uncompressed data of the current block.
std::string current_block;
absl::optional<utils::StringViewInputByteStream> current_block_reader;
std::string current_block_;
std::string current_block_decompressed_;
std::string zlib_working_buffer_;
absl::optional<utils::StringViewInputByteStream> current_block_reader_;

size_t num_objects_in_current_block_ = 0;
size_t next_object_in_current_block_ = 0;
Expand Down
16 changes: 12 additions & 4 deletions yggdrasil_decision_forests/dataset/avro_example_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,16 @@ TEST(AvroExample, CreateDataspec) {
EXPECT_THAT(dataspec, EqualsProto(expected));
}

TEST(AvroExample, ReadExample) {
struct ReadExampleCase {
std::string filename;
};

SIMPLE_PARAMETERIZED_TEST(ReadExample, ReadExampleCase,
{
{"toy_codex-null.avro"},
{"toy_codex-deflate.avro"},
}) {
const auto& test_case = GetParam();
dataset::proto::DataSpecificationGuide guide;
{
auto* col = guide.add_column_guides();
Expand All @@ -367,11 +376,10 @@ TEST(AvroExample, ReadExample) {

ASSERT_OK_AND_ASSIGN(
const auto dataspec,
CreateDataspec(file::JoinPath(DatasetDir(), "toy_codex-null.avro"),
guide));
CreateDataspec(file::JoinPath(DatasetDir(), test_case.filename), guide));

AvroExampleReader reader(dataspec, {});
ASSERT_OK(reader.Open(file::JoinPath(DatasetDir(), "toy_codex-null.avro")));
ASSERT_OK(reader.Open(file::JoinPath(DatasetDir(), test_case.filename)));
proto::Example example;
ASSERT_OK_AND_ASSIGN(bool has_next, reader.Next(&example));
ASSERT_TRUE(has_next);
Expand Down
60 changes: 44 additions & 16 deletions yggdrasil_decision_forests/utils/zlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <utility>

#include "absl/log/check.h"
Expand All @@ -44,19 +45,11 @@ GZipInputByteStream::Create(std::unique_ptr<utils::InputByteStream>&& stream,
size_t buffer_size) {
auto gz_stream =
std::make_unique<GZipInputByteStream>(std::move(stream), buffer_size);

gz_stream->deflate_stream_.zalloc = Z_NULL;
gz_stream->deflate_stream_.zfree = Z_NULL;
gz_stream->deflate_stream_.opaque = Z_NULL;
gz_stream->deflate_stream_.avail_in = 0;
gz_stream->deflate_stream_.next_in = Z_NULL;
std::memset(&gz_stream->deflate_stream_, 0,
sizeof(gz_stream->deflate_stream_));
if (inflateInit2(&gz_stream->deflate_stream_, 16 + MAX_WBITS) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
}
// gz_stream->deflate_stream_.next_in = gz_stream->input_buffer_.data();
// gz_stream->deflate_stream_.avail_in = 0;
// gz_stream->deflate_stream_.next_out = gz_stream->output_buffer_.data();
// gz_stream->deflate_stream_.avail_out = 0;
gz_stream->deflate_stream_is_allocated_ = true;
return gz_stream;
}
Expand Down Expand Up @@ -163,12 +156,8 @@ GZipOutputByteStream::Create(std::unique_ptr<utils::OutputByteStream>&& stream,
}
auto gz_stream =
std::make_unique<GZipOutputByteStream>(std::move(stream), buffer_size);

gz_stream->deflate_stream_.zalloc = Z_NULL;
gz_stream->deflate_stream_.zfree = Z_NULL;
gz_stream->deflate_stream_.opaque = Z_NULL;
gz_stream->deflate_stream_.avail_in = 0;
gz_stream->deflate_stream_.next_in = Z_NULL;
std::memset(&gz_stream->deflate_stream_, 0,
sizeof(gz_stream->deflate_stream_));
if (deflateInit2(&gz_stream->deflate_stream_, compression_level, Z_DEFLATED,
MAX_WBITS + 16,
/*memLevel=*/8, // 8 is the recommended default
Expand Down Expand Up @@ -255,6 +244,45 @@ absl::Status GZipOutputByteStream::CloseInflateStream() {
return absl::OkStatus();
}

absl::Status Inflate(absl::string_view input, std::string* output,
std::string* working_buffer) {
if (working_buffer->size() < 1024) {
return absl::InvalidArgumentError(
"worker buffer should be at least 1024 bytes");
}
z_stream stream;
std::memset(&stream, 0, sizeof(stream));
// Note: A negative window size indicate to use the raw deflate algorithm (!=
// zlib or gzip).
if (inflateInit2(&stream, -15) != Z_OK) {
return absl::InternalError("Cannot initialize gzip stream");
}
stream.next_in = reinterpret_cast<const Bytef*>(input.data());
stream.avail_in = input.size();

while (true) {
stream.next_out = reinterpret_cast<Bytef*>(&(*working_buffer)[0]);
stream.avail_out = working_buffer->size();
const auto zlib_error = inflate(&stream, Z_NO_FLUSH);
if (zlib_error != Z_OK && zlib_error != Z_STREAM_END) {
inflateEnd(&stream);
return absl::InternalError(absl::StrCat("Internal error", zlib_error));
}
if (stream.avail_out == 0) {
break;
}
const size_t produced_bytes = working_buffer->size() - stream.avail_out;
absl::StrAppend(output,
absl::string_view{working_buffer->data(), produced_bytes});
if (zlib_error == Z_STREAM_END) {
break;
}
}
inflateEnd(&stream);

return absl::OkStatus();
}

} // namespace yggdrasil_decision_forests::utils

#endif // THIRD_PARTY_YGGDRASIL_DECISION_FORESTS_UTILS_GZIP_H_
Loading

0 comments on commit 3843a7a

Please sign in to comment.