Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use column projection during update #322

Merged
merged 10 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions cpp/include/lance/arrow/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,16 @@ class LanceDataset : public ::arrow::dataset::Dataset {
::arrow::Result<std::shared_ptr<::arrow::dataset::Dataset>> ReplaceSchema(
std::shared_ptr<::arrow::Schema> schema) const override;

/// Add a column whose values are computed by a compute expression.
///
/// The expression is bound to this dataset's schema and then evaluated
/// against each batch read from the dataset. A literal expression
/// (e.g. `::arrow::compute::literal(0.5)`) produces a constant column.
///
/// \param field the new field.
/// \param expression the expression to compute the column. It must be a
///        scalar expression; binding to the dataset schema is done internally.
/// \return a new version of the dataset, or `Status::Invalid` if
///         `expression` is not a scalar expression.
///
/// See `Updater` for details.
::arrow::Result<std::shared_ptr<LanceDataset>> AddColumn(
const std::shared_ptr<::arrow::Field>& field, ::arrow::compute::Expression expression);

protected:
::arrow::Result<::arrow::dataset::FragmentIterator> GetFragmentsImpl(
::arrow::compute::Expression predicate) override;
Expand Down
12 changes: 11 additions & 1 deletion cpp/include/lance/arrow/updater.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <arrow/status.h>

#include <memory>
#include <string>
#include <vector>

#include "lance/arrow/dataset.h"

Expand Down Expand Up @@ -73,10 +75,13 @@ class Updater {
///
/// \param dataset The dataset to be updated.
/// \param field the (new) column to update.
/// \param projection_columns the columns to read from source dataset.
///
/// \return an Updater if success.
static ::arrow::Result<std::shared_ptr<Updater>> Make(
std::shared_ptr<LanceDataset> dataset, const std::shared_ptr<::arrow::Field>& field);
std::shared_ptr<LanceDataset> dataset,
const std::shared_ptr<::arrow::Field>& field,
const std::vector<std::string>& projection_columns);

/// PIMPL
class Impl;
Expand All @@ -96,12 +101,17 @@ class UpdaterBuilder {
public:
UpdaterBuilder(std::shared_ptr<LanceDataset> dataset, std::shared_ptr<::arrow::Field> field);

/// Set the projection columns from the source dataset.
///
/// \param columns names of the columns to read from the source dataset
///        while updating. If the list is empty, no projection is applied
///        to the scan of the source dataset.
void Project(std::vector<std::string> columns);

::arrow::Result<std::shared_ptr<Updater>> Finish();

private:
std::shared_ptr<LanceDataset> dataset_;

std::shared_ptr<::arrow::Field> field_;

std::vector<std::string> projection_columns_;
};

} // namespace lance::arrow
34 changes: 34 additions & 0 deletions cpp/src/lance/arrow/dataset.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "lance/arrow/dataset.h"

#include <arrow/array.h>
#include <arrow/array/concatenate.h>
#include <arrow/dataset/api.h>
#include <arrow/status.h>
#include <arrow/table.h>
Expand Down Expand Up @@ -316,6 +317,39 @@ ::arrow::Result<std::shared_ptr<UpdaterBuilder>> LanceDataset::NewUpdate(
std::move(new_field));
}

::arrow::Result<std::shared_ptr<LanceDataset>> LanceDataset::AddColumn(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, the problem here is what happens if the compute expression contains aggregates.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be checked via bool Expression::IsScalarExpression() const.

I can return an Invalid status from AddColumn.

const std::shared_ptr<::arrow::Field>& field, ::arrow::compute::Expression expression) {
if (!expression.IsScalarExpression()) {
return ::arrow::Status::Invalid(
"LanceDataset::AddColumn: expression is not a scalar expression.");
}
ARROW_ASSIGN_OR_RAISE(expression, expression.Bind(*schema()));
ARROW_ASSIGN_OR_RAISE(auto builder, NewUpdate(field));
ARROW_ASSIGN_OR_RAISE(auto updater, builder->Finish());

// TODO: add projection via FieldRef.
while (true) {
ARROW_ASSIGN_OR_RAISE(auto batch, updater->Next());
if (!batch) {
break;
}

ARROW_ASSIGN_OR_RAISE(auto datum,
::arrow::compute::ExecuteScalarExpression(expression, *schema(), batch));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this bind the expression or is it required to be bound before the AddColumn method is called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::shared_ptr<::arrow::Array> arr;
if (datum.is_scalar()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so this is a constant literal value?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is for cases like AddColumn(field, pc::literal(1234)).

ARROW_ASSIGN_OR_RAISE(arr, CreateArray(datum.scalar(), batch->num_rows()));
} else if (datum.is_chunked_array()) {
auto chunked_arr = datum.chunked_array();
ARROW_ASSIGN_OR_RAISE(arr, ::arrow::Concatenate(chunked_arr->chunks()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ExtensionArrays cannot be concatenated currently - though compute expressions won't work on them either, so ExtensionArrays probably won't make it past ExecuteScalarExpression?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a test as a follow-up? Also, there is no function / kernel available for extension types right now, so this method might fail earlier.

} else {
arr = datum.make_array();
}
ARROW_RETURN_NOT_OK(updater->UpdateBatch(arr));
}
return updater->Finish();
}

::arrow::Result<std::shared_ptr<::arrow::dataset::Dataset>> LanceDataset::ReplaceSchema(
[[maybe_unused]] std::shared_ptr<::arrow::Schema> schema) const {
return std::make_shared<LanceDataset>(*this);
Expand Down
73 changes: 61 additions & 12 deletions cpp/src/lance/arrow/dataset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ std::string WriteTable(const std::shared_ptr<::arrow::Table>& table) {
write_options.base_dir = path;
write_options.file_write_options = format->DefaultWriteOptions();

auto dataset = lance::testing::MakeDataset(table).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
CHECK(lance::arrow::LanceDataset::Write(write_options,
dataset->NewScan().ValueOrDie()->Finish().ValueOrDie())
.ok());
Expand All @@ -62,7 +62,7 @@ TEST_CASE("Create new dataset") {
::arrow::field("value", ::arrow::utf8())}),
{ids, values});

auto dataset = lance::testing::MakeDataset(table1).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table1);

auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata";
auto format = lance::arrow::LanceFileFormat::Make();
Expand All @@ -84,14 +84,14 @@ TEST_CASE("Create new dataset") {
{ids, values});

// Version 2 is appending.
dataset = lance::testing::MakeDataset(table2).ValueOrDie();
dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table2);
CHECK(lance::arrow::LanceDataset::Write(write_options,
dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(),
lance::arrow::LanceDataset::kAppend)
.ok());

// Version 3 is overwriting.
dataset = lance::testing::MakeDataset(table2).ValueOrDie();
dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table2);
CHECK(lance::arrow::LanceDataset::Write(write_options,
dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(),
lance::arrow::LanceDataset::kOverwrite)
Expand Down Expand Up @@ -125,7 +125,7 @@ TEST_CASE("Create new dataset over existing dataset") {
auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids});
auto dataset = lance::testing::MakeDataset(table).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);

auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata";
auto format = lance::arrow::LanceFileFormat::Make();
Expand All @@ -150,7 +150,7 @@ TEST_CASE("Dataset append error cases") {
auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids});
auto dataset = lance::testing::MakeDataset(table).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);

auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata";
auto format = lance::arrow::LanceFileFormat::Make();
Expand All @@ -177,7 +177,7 @@ TEST_CASE("Dataset append error cases") {
auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie();
table = ::arrow::Table::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}),
{values});
dataset = lance::testing::MakeDataset(table).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
auto status =
lance::arrow::LanceDataset::Write(write_options,
dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(),
Expand All @@ -190,7 +190,7 @@ TEST_CASE("Dataset overwrite error cases") {
auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids});
auto dataset = lance::testing::MakeDataset(table).ValueOrDie();
auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);

auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata";
auto format = lance::arrow::LanceFileFormat::Make();
Expand All @@ -209,7 +209,7 @@ TEST_CASE("Dataset overwrite error cases") {
auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie();
table =
::arrow::Table::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}), {values});
dataset = lance::testing::MakeDataset(table).ValueOrDie();
dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table);
auto status =
lance::arrow::LanceDataset::Write(write_options,
dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(),
Expand All @@ -223,9 +223,7 @@ TEST_CASE("Dataset write dictionary array") {
auto dict_indices = ToArray({0, 1, 1, 2, 2, 0}).ValueOrDie();
auto data_type = ::arrow::dictionary(::arrow::int32(), ::arrow::utf8());
auto dict_arr =
::arrow::DictionaryArray::FromArrays(
data_type, dict_indices, dict_values)
.ValueOrDie();
::arrow::DictionaryArray::FromArrays(data_type, dict_indices, dict_values).ValueOrDie();
auto table =
::arrow::Table::Make(::arrow::schema({::arrow::field("dict", data_type)}), {dict_arr});

Expand All @@ -234,4 +232,55 @@ TEST_CASE("Dataset write dictionary array") {

auto actual = ReadTable(base_uri, 1);
CHECK(actual->Equals(*table));
}

// Adding a column from a literal expression must yield a new dataset version
// in which every row of the new column holds that literal.
TEST_CASE("Dataset add column with a constant value") {
  const auto id_field = ::arrow::field("id", ::arrow::int32());
  const auto doubles_field = ::arrow::field("doubles", ::arrow::float64());

  // Write a single-column table to a fresh Lance dataset.
  auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie();
  auto source_table = ::arrow::Table::Make(::arrow::schema({id_field}), {ids});
  auto base_uri = WriteTable(source_table);
  // Read the data back; the returned table is not inspected further here.
  // NOTE(review): the second argument is presumably the dataset version - confirm.
  auto actual = ReadTable(base_uri, 1);

  auto local_fs = std::make_shared<::arrow::fs::LocalFileSystem>();
  auto original = lance::arrow::LanceDataset::Make(local_fs, base_uri).ValueOrDie();

  // The constant expression is broadcast to every row of the new column.
  auto updated = original->AddColumn(doubles_field, ::arrow::compute::literal(0.5)).ValueOrDie();
  CHECK(updated->version().version() == 2);

  // Scan the updated dataset step by step and materialize it as a table.
  auto scan_builder = updated->NewScan().ValueOrDie();
  auto scanner = scan_builder->Finish().ValueOrDie();
  auto updated_table = scanner->ToTable().ValueOrDie();

  auto doubles = ToArray<double>({0.5, 0.5, 0.5, 0.5, 0.5}).ValueOrDie();
  auto expected_table =
      ::arrow::Table::Make(::arrow::schema({id_field, doubles_field}), {ids, doubles});
  CHECK(updated_table->Equals(*expected_table));
}

// Adding a column from a function-call expression ("id" + 0.5) must yield a
// new dataset version whose new column holds the per-row result.
TEST_CASE("Dataset add column with a function call") {
  const auto id_field = ::arrow::field("id", ::arrow::int32());
  const auto doubles_field = ::arrow::field("doubles", ::arrow::float64());

  // Write a single-column table to a fresh Lance dataset.
  auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie();
  auto source_table = ::arrow::Table::Make(::arrow::schema({id_field}), {ids});
  auto base_uri = WriteTable(source_table);
  // Read the data back; the returned table is not inspected further here.
  // NOTE(review): the second argument is presumably the dataset version - confirm.
  auto actual = ReadTable(base_uri, 1);

  auto local_fs = std::make_shared<::arrow::fs::LocalFileSystem>();
  auto original = lance::arrow::LanceDataset::Make(local_fs, base_uri).ValueOrDie();

  // doubles = id + 0.5, computed row by row from the existing "id" column.
  auto id_plus_half = ::arrow::compute::call(
      "add", {::arrow::compute::field_ref("id"), ::arrow::compute::literal(0.5)});
  auto updated = original->AddColumn(doubles_field, id_plus_half).ValueOrDie();
  CHECK(updated->version().version() == 2);

  // Scan the updated dataset step by step and materialize it as a table.
  auto scan_builder = updated->NewScan().ValueOrDie();
  auto scanner = scan_builder->Finish().ValueOrDie();
  auto updated_table = scanner->ToTable().ValueOrDie();

  auto doubles = ToArray<double>({1.5, 2.5, 3.5, 4.5, 5.5}).ValueOrDie();
  auto expected_table =
      ::arrow::Table::Make(::arrow::schema({id_field, doubles_field}), {ids, doubles});
  CHECK(updated_table->Equals(*expected_table));
}
27 changes: 22 additions & 5 deletions cpp/src/lance/arrow/updater.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ class Updater::Impl {
Impl(std::shared_ptr<LanceDataset> dataset,
::arrow::dataset::FragmentVector fragments,
std::shared_ptr<lance::format::Schema> full_schema,
std::shared_ptr<lance::format::Schema> column_schema)
std::shared_ptr<lance::format::Schema> column_schema,
std::vector<std::string> projection_columns)
: dataset_(std::move(dataset)),
full_schema_(std::move(full_schema)),
column_schema_(std::move(column_schema)),
fragments_(std::move(fragments)),
projected_columns_(std::move(projection_columns)),
fragment_it_(fragments_.begin()) {}

/// Copy constructor
Expand All @@ -54,6 +56,7 @@ class Updater::Impl {
full_schema_(other.full_schema_),
column_schema_(other.column_schema_),
fragments_(other.fragments_.begin(), other.fragments_.end()),
projected_columns_(other.projected_columns_.begin(), other.projected_columns_.end()),
fragment_it_(fragments_.begin()) {}

::arrow::Result<std::shared_ptr<::arrow::RecordBatch>> Next();
Expand All @@ -78,6 +81,8 @@ class Updater::Impl {
/// A copy of fragments.
::arrow::dataset::FragmentVector fragments_;

std::vector<std::string> projected_columns_;

// Used to store the updated fragments.
std::vector<std::shared_ptr<format::DataFragment>> data_fragments_;

Expand All @@ -103,6 +108,9 @@ ::arrow::Status Updater::Impl::NextFragment() {
std::make_unique<io::FileWriter>(column_schema_, std::move(write_options), std::move(output));

ARROW_ASSIGN_OR_RAISE(auto scan_builder, dataset_->NewScan());
if (!projected_columns_.empty()) {
ARROW_RETURN_NOT_OK(scan_builder->Project(projected_columns_));
}
// ARROW_RETURN_NOT_OK(scan_builder->BatchSize(std::numeric_limits<int64_t>::max()));
ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish());
ARROW_ASSIGN_OR_RAISE(batch_generator_, (*fragment_it_)->ScanBatchesAsync(scanner->options()));
Expand Down Expand Up @@ -175,16 +183,21 @@ ::arrow::Result<std::shared_ptr<LanceDataset>> Updater::Impl::Finish() {
Updater::~Updater() {}

::arrow::Result<std::shared_ptr<Updater>> Updater::Make(
std::shared_ptr<LanceDataset> dataset, const std::shared_ptr<::arrow::Field>& field) {
std::shared_ptr<LanceDataset> dataset,
const std::shared_ptr<::arrow::Field>& field,
const std::vector<std::string>& projection_columns) {
auto arrow_schema = ::arrow::schema({field});
ARROW_ASSIGN_OR_RAISE(auto full_schema, dataset->impl_->manifest->schema()->Merge(*arrow_schema));
ARROW_ASSIGN_OR_RAISE(auto column_schema, full_schema->Project(*arrow_schema));
ARROW_ASSIGN_OR_RAISE(auto fragment_iter, dataset->GetFragments());
// Use vector to make implementation easier.
// We can later switch to FragmentIterator for datasets with a lot of fragments.
ARROW_ASSIGN_OR_RAISE(auto fragments, fragment_iter.ToVector());
auto impl = std::make_unique<Impl>(
std::move(dataset), std::move(fragments), std::move(full_schema), std::move(column_schema));
auto impl = std::make_unique<Impl>(std::move(dataset),
std::move(fragments),
std::move(full_schema),
std::move(column_schema),
projection_columns);
return std::shared_ptr<Updater>(new Updater(std::move(impl)));
}

Expand All @@ -202,8 +215,12 @@ UpdaterBuilder::UpdaterBuilder(std::shared_ptr<LanceDataset> source,
std::shared_ptr<::arrow::Field> field)
: dataset_(std::move(source)), field_(std::move(field)) {}

/// Record the names of the source-dataset columns to read during the update.
///
/// \param columns the column names; taken by value so the builder can take
///        over the storage without copying.
void UpdaterBuilder::Project(std::vector<std::string> columns) {
  // Take over the caller's vector; any previous selection is released when
  // `columns` goes out of scope.
  projection_columns_.swap(columns);
}

::arrow::Result<std::shared_ptr<Updater>> UpdaterBuilder::Finish() {
return Updater::Make(dataset_, field_);
return Updater::Make(dataset_, field_, projection_columns_);
}

} // namespace lance::arrow
Loading