diff --git a/cpp/include/lance/arrow/dataset.h b/cpp/include/lance/arrow/dataset.h index c49f42ac1a..36370b52ed 100644 --- a/cpp/include/lance/arrow/dataset.h +++ b/cpp/include/lance/arrow/dataset.h @@ -132,6 +132,16 @@ class LanceDataset : public ::arrow::dataset::Dataset { ::arrow::Result> ReplaceSchema( std::shared_ptr<::arrow::Schema> schema) const override; + /// Add column via a compute expression. + /// + /// \param field the new field. + /// \param expression the expression to compute the column. + /// \return a new version of the dataset. + /// + /// See `Updater` for details. + ::arrow::Result> AddColumn( + const std::shared_ptr<::arrow::Field>& field, ::arrow::compute::Expression expression); + protected: ::arrow::Result<::arrow::dataset::FragmentIterator> GetFragmentsImpl( ::arrow::compute::Expression predicate) override; diff --git a/cpp/include/lance/arrow/updater.h b/cpp/include/lance/arrow/updater.h index 0fe3083192..0c9df6c518 100644 --- a/cpp/include/lance/arrow/updater.h +++ b/cpp/include/lance/arrow/updater.h @@ -20,6 +20,8 @@ #include #include +#include +#include #include "lance/arrow/dataset.h" @@ -73,10 +75,13 @@ class Updater { /// /// \param dataset The dataset to be updated. /// \param field the (new) column to update. + /// \param projection_columns the columns to read from source dataset. /// /// \return an Updater if success. static ::arrow::Result> Make( - std::shared_ptr dataset, const std::shared_ptr<::arrow::Field>& field); + std::shared_ptr dataset, + const std::shared_ptr<::arrow::Field>& field, + const std::vector& projection_columns); /// PIMPL class Impl; @@ -96,12 +101,17 @@ class UpdaterBuilder { public: UpdaterBuilder(std::shared_ptr dataset, std::shared_ptr<::arrow::Field> field); + /// Set the projection columns from the source dataset. + void Project(std::vector columns); + ::arrow::Result> Finish(); private: std::shared_ptr dataset_; std::shared_ptr<::arrow::Field> field_; + + std::vector projection_columns_; }; } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/dataset.cc b/cpp/src/lance/arrow/dataset.cc index 8323a05d1c..f4a5c98816 100644 --- a/cpp/src/lance/arrow/dataset.cc +++ b/cpp/src/lance/arrow/dataset.cc @@ -15,6 +15,7 @@ #include "lance/arrow/dataset.h" #include +#include #include #include #include @@ -316,6 +317,39 @@ ::arrow::Result> LanceDataset::NewUpdate( std::move(new_field)); } +::arrow::Result> LanceDataset::AddColumn( + const std::shared_ptr<::arrow::Field>& field, ::arrow::compute::Expression expression) { + if (!expression.IsScalarExpression()) { + return ::arrow::Status::Invalid( + "LanceDataset::AddColumn: expression is not a scalar expression."); + } + ARROW_ASSIGN_OR_RAISE(expression, expression.Bind(*schema())); + ARROW_ASSIGN_OR_RAISE(auto builder, NewUpdate(field)); + ARROW_ASSIGN_OR_RAISE(auto updater, builder->Finish()); + + // TODO: add projection via FieldRef. + while (true) { + ARROW_ASSIGN_OR_RAISE(auto batch, updater->Next()); + if (!batch) { + break; + } + + ARROW_ASSIGN_OR_RAISE(auto datum, + ::arrow::compute::ExecuteScalarExpression(expression, *schema(), batch)); + std::shared_ptr<::arrow::Array> arr; + if (datum.is_scalar()) { + ARROW_ASSIGN_OR_RAISE(arr, CreateArray(datum.scalar(), batch->num_rows())); + } else if (datum.is_chunked_array()) { + auto chunked_arr = datum.chunked_array(); + ARROW_ASSIGN_OR_RAISE(arr, ::arrow::Concatenate(chunked_arr->chunks())); + } else { + arr = datum.make_array(); + } + ARROW_RETURN_NOT_OK(updater->UpdateBatch(arr)); + } + return updater->Finish(); +} + ::arrow::Result> LanceDataset::ReplaceSchema( [[maybe_unused]] std::shared_ptr<::arrow::Schema> schema) const { return std::make_shared(*this); diff --git a/cpp/src/lance/arrow/dataset_test.cc b/cpp/src/lance/arrow/dataset_test.cc index 1484b712c8..eb8334287d 100644 --- a/cpp/src/lance/arrow/dataset_test.cc +++ b/cpp/src/lance/arrow/dataset_test.cc @@ -48,7 +48,7 @@ std::string WriteTable(const std::shared_ptr<::arrow::Table>& table) { write_options.base_dir = path; write_options.file_write_options = format->DefaultWriteOptions(); - auto dataset = lance::testing::MakeDataset(table).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); CHECK(lance::arrow::LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()) .ok()); @@ -62,7 +62,7 @@ TEST_CASE("Create new dataset") { ::arrow::field("value", ::arrow::utf8())}), {ids, values}); - auto dataset = lance::testing::MakeDataset(table1).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table1); auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata"; auto format = lance::arrow::LanceFileFormat::Make(); @@ -84,14 +84,14 @@ TEST_CASE("Create new dataset") { {ids, values}); // Version 2 is appending. - dataset = lance::testing::MakeDataset(table2).ValueOrDie(); + dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table2); CHECK(lance::arrow::LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(), lance::arrow::LanceDataset::kAppend) .ok()); // Version 3 is overwriting. - dataset = lance::testing::MakeDataset(table2).ValueOrDie(); + dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table2); CHECK(lance::arrow::LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(), lance::arrow::LanceDataset::kOverwrite) @@ -125,7 +125,7 @@ TEST_CASE("Create new dataset over existing dataset") { auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); auto table = ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids}); - auto dataset = lance::testing::MakeDataset(table).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata"; auto format = lance::arrow::LanceFileFormat::Make(); @@ -150,7 +150,7 @@ TEST_CASE("Dataset append error cases") { auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); auto table = ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids}); - auto dataset = lance::testing::MakeDataset(table).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata"; auto format = lance::arrow::LanceFileFormat::Make(); @@ -177,7 +177,7 @@ TEST_CASE("Dataset append error cases") { auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie(); table = ::arrow::Table::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}), {values}); - dataset = lance::testing::MakeDataset(table).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto status = lance::arrow::LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(), @@ -190,7 +190,7 @@ TEST_CASE("Dataset overwrite error cases") { auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); auto table = ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids}); - auto dataset = lance::testing::MakeDataset(table).ValueOrDie(); + auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto base_uri = lance::testing::MakeTemporaryDir().ValueOrDie() + "/testdata"; auto format = lance::arrow::LanceFileFormat::Make(); @@ -209,7 +209,7 @@ TEST_CASE("Dataset overwrite error cases") { auto values = ToArray({"one", "two", "three", "four", "five"}).ValueOrDie(); table = ::arrow::Table::Make(::arrow::schema({::arrow::field("values", ::arrow::utf8())}), {values}); - dataset = lance::testing::MakeDataset(table).ValueOrDie(); + dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto status = lance::arrow::LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie(), @@ -223,9 +223,7 @@ TEST_CASE("Dataset write dictionary array") { auto dict_indices = ToArray({0, 1, 1, 2, 2, 0}).ValueOrDie(); auto data_type = ::arrow::dictionary(::arrow::int32(), ::arrow::utf8()); auto dict_arr = - ::arrow::DictionaryArray::FromArrays( - data_type, dict_indices, dict_values) - .ValueOrDie(); + ::arrow::DictionaryArray::FromArrays(data_type, dict_indices, dict_values).ValueOrDie(); auto table = ::arrow::Table::Make(::arrow::schema({::arrow::field("dict", data_type)}), {dict_arr}); @@ -234,4 +232,55 @@ TEST_CASE("Dataset write dictionary array") { auto actual = ReadTable(base_uri, 1); CHECK(actual->Equals(*table)); +} + +TEST_CASE("Dataset add column with a constant value") { + auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids}); + auto base_uri = WriteTable(table); + auto actual = ReadTable(base_uri, 1); + + auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); + auto dataset = lance::arrow::LanceDataset::Make(fs, base_uri).ValueOrDie(); + + auto dataset2 = + dataset + ->AddColumn(::arrow::field("doubles", ::arrow::float64()), ::arrow::compute::literal(0.5)) + .ValueOrDie(); + CHECK(dataset2->version().version() == 2); + auto table2 = dataset2->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); + auto doubles = ToArray({0.5, 0.5, 0.5, 0.5, 0.5}).ValueOrDie(); + auto expected_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("doubles", ::arrow::float64())}), + {ids, doubles}); + CHECK(table2->Equals(*expected_table)); +} + +TEST_CASE("Dataset add column with a function call") { + auto ids = ToArray({1, 2, 3, 4, 5}).ValueOrDie(); + auto table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32())}), {ids}); + auto base_uri = WriteTable(table); + auto actual = ReadTable(base_uri, 1); + + auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); + auto dataset = lance::arrow::LanceDataset::Make(fs, base_uri).ValueOrDie(); + + auto dataset2 = + dataset + ->AddColumn( + ::arrow::field("doubles", ::arrow::float64()), + ::arrow::compute::call( + "add", {::arrow::compute::field_ref("id"), ::arrow::compute::literal(0.5)})) + .ValueOrDie(); + CHECK(dataset2->version().version() == 2); + auto table2 = dataset2->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); + auto doubles = ToArray({1.5, 2.5, 3.5, 4.5, 5.5}).ValueOrDie(); + auto expected_table = + ::arrow::Table::Make(::arrow::schema({::arrow::field("id", ::arrow::int32()), + ::arrow::field("doubles", ::arrow::float64())}), + {ids, doubles}); + CHECK(table2->Equals(*expected_table)); } \ No newline at end of file diff --git a/cpp/src/lance/arrow/updater.cc b/cpp/src/lance/arrow/updater.cc index afd7237a93..8830041a66 100644 --- a/cpp/src/lance/arrow/updater.cc +++ b/cpp/src/lance/arrow/updater.cc @@ -41,11 +41,13 @@ class Updater::Impl { Impl(std::shared_ptr dataset, ::arrow::dataset::FragmentVector fragments, std::shared_ptr full_schema, - std::shared_ptr column_schema) + std::shared_ptr column_schema, + std::vector projection_columns) : dataset_(std::move(dataset)), full_schema_(std::move(full_schema)), column_schema_(std::move(column_schema)), fragments_(std::move(fragments)), + projected_columns_(std::move(projection_columns)), fragment_it_(fragments_.begin()) {} /// Copy constructor @@ -54,6 +56,7 @@ class Updater::Impl { full_schema_(other.full_schema_), column_schema_(other.column_schema_), fragments_(other.fragments_.begin(), other.fragments_.end()), + projected_columns_(other.projected_columns_.begin(), other.projected_columns_.end()), fragment_it_(fragments_.begin()) {} ::arrow::Result> Next(); @@ -78,6 +81,8 @@ class Updater::Impl { /// A copy of fragments. ::arrow::dataset::FragmentVector fragments_; + std::vector projected_columns_; + // Used to store the updated fragments. std::vector> data_fragments_; @@ -103,6 +108,9 @@ ::arrow::Status Updater::Impl::NextFragment() { std::make_unique(column_schema_, std::move(write_options), std::move(output)); ARROW_ASSIGN_OR_RAISE(auto scan_builder, dataset_->NewScan()); + if (!projected_columns_.empty()) { + ARROW_RETURN_NOT_OK(scan_builder->Project(projected_columns_)); + } // ARROW_RETURN_NOT_OK(scan_builder->BatchSize(std::numeric_limits::max())); ARROW_ASSIGN_OR_RAISE(auto scanner, scan_builder->Finish()); ARROW_ASSIGN_OR_RAISE(batch_generator_, (*fragment_it_)->ScanBatchesAsync(scanner->options())); @@ -175,7 +183,9 @@ ::arrow::Result> Updater::Impl::Finish() { Updater::~Updater() {} ::arrow::Result> Updater::Make( - std::shared_ptr dataset, const std::shared_ptr<::arrow::Field>& field) { + std::shared_ptr dataset, + const std::shared_ptr<::arrow::Field>& field, + const std::vector& projection_columns) { auto arrow_schema = ::arrow::schema({field}); ARROW_ASSIGN_OR_RAISE(auto full_schema, dataset->impl_->manifest->schema()->Merge(*arrow_schema)); ARROW_ASSIGN_OR_RAISE(auto column_schema, full_schema->Project(*arrow_schema)); @@ -183,8 +193,11 @@ ::arrow::Result> Updater::Make( // Use vector to make implementation easier. // We can later to use FragmentIterator for datasets with a lot of Fragments. ARROW_ASSIGN_OR_RAISE(auto fragments, fragment_iter.ToVector()); - auto impl = std::make_unique( - std::move(dataset), std::move(fragments), std::move(full_schema), std::move(column_schema)); + auto impl = std::make_unique(std::move(dataset), + std::move(fragments), + std::move(full_schema), + std::move(column_schema), + projection_columns); return std::shared_ptr(new Updater(std::move(impl))); } @@ -202,8 +215,12 @@ UpdaterBuilder::UpdaterBuilder(std::shared_ptr source, std::shared_ptr<::arrow::Field> field) : dataset_(std::move(source)), field_(std::move(field)) {} +void UpdaterBuilder::Project(std::vector columns) { + projection_columns_ = std::move(columns); +} + ::arrow::Result> UpdaterBuilder::Finish() { - return Updater::Make(dataset_, field_); + return Updater::Make(dataset_, field_, projection_columns_); } } // namespace lance::arrow \ No newline at end of file diff --git a/cpp/src/lance/arrow/updater_test.cc b/cpp/src/lance/arrow/updater_test.cc index a0758193cf..f27572c1d7 100644 --- a/cpp/src/lance/arrow/updater_test.cc +++ b/cpp/src/lance/arrow/updater_test.cc @@ -14,6 +14,7 @@ #include "lance/arrow/updater.h" +#include #include #include #include @@ -34,16 +35,20 @@ using lance::arrow::LanceDataset; using lance::arrow::ToArray; namespace fs = std::filesystem; -TEST_CASE("Use updater to update one column") { +std::shared_ptr TestDataset() { auto ints = views::iota(0, 100) | to>(); auto ints_arr = ToArray(ints).ValueOrDie(); - auto schema = arrow::schema({arrow::field("ints", arrow::int32())}); - auto table = arrow::Table::Make(schema, {ints_arr}); + auto strs = ints | views::transform([](auto v) { return fmt::format("{}", v); }) | + to>; + auto strs_arr = ToArray(strs).ValueOrDie(); + auto schema = + arrow::schema({arrow::field("ints", arrow::int32()), ::arrow::field("strs", arrow::utf8())}); + auto table = arrow::Table::Make(schema, {ints_arr, strs_arr}); auto dataset = std::make_shared<::arrow::dataset::InMemoryDataset>(table); auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); - auto dataset_uri = fs::path(lance::testing::MakeTemporaryDir().ValueOrDie()) / "update"; + auto dataset_uri = fs::path(lance::testing::MakeTemporaryDir().ValueOrDie()) / "data"; auto format = lance::arrow::LanceFileFormat::Make(); ::arrow::dataset::FileSystemDatasetWriteOptions write_options; @@ -53,12 +58,15 @@ TEST_CASE("Use updater to update one column") { write_options.file_write_options = format->DefaultWriteOptions(); CHECK(LanceDataset::Write(write_options, dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()) .ok()); - fmt::print("Dataset URI: {}\n", dataset_uri.string()); + return LanceDataset::Make(fs, dataset_uri).ValueOrDie(); +} - auto lance_dataset = LanceDataset::Make(fs, dataset_uri).ValueOrDie(); +TEST_CASE("Use updater to update one column") { + auto lance_dataset = TestDataset(); CHECK(lance_dataset->version().version() == 1); + auto table = lance_dataset->NewScan().ValueOrDie()->Finish().ValueOrDie()->ToTable().ValueOrDie(); - auto updater = lance_dataset->NewUpdate(::arrow::field("strs", arrow::utf8())) + auto updater = lance_dataset->NewUpdate(::arrow::field("values", arrow::utf8())) .ValueOrDie() ->Finish() .ValueOrDie(); @@ -69,7 +77,7 @@ TEST_CASE("Use updater to update one column") { break; } cnt++; - CHECK(batch->schema()->Equals(*table->schema())); + CHECK(batch->schema()->Equals(*lance_dataset->schema())); auto input_arr = batch->GetColumnByName("ints"); auto datum = ::arrow::compute::Cast(input_arr, ::arrow::utf8()).ValueOrDie(); auto output_arr = datum.make_array(); @@ -84,8 +92,42 @@ TEST_CASE("Use updater to update one column") { ToArray(views::iota(0, 100) | views::transform([](auto i) { return fmt::format("{}", i); }) | to>) .ValueOrDie(); - auto expected = arrow::Table::Make( - ::arrow::schema({arrow::field("ints", arrow::int32()), arrow::field("strs", arrow::utf8())}), - {ints_arr, expected_strs_arr}); + auto expected = + arrow::Table::Make(::arrow::schema({arrow::field("ints", arrow::int32()), + arrow::field("strs", arrow::utf8()), + arrow::field("values", arrow::utf8())}), + {table->GetColumnByName("ints"), + table->GetColumnByName("strs"), + std::make_shared<::arrow::ChunkedArray>(expected_strs_arr)}); CHECK(expected->Equals(*actual)); +} + +TEST_CASE("Batch must be consumed before the next iteration") { + auto dataset = TestDataset(); + auto updater = dataset->NewUpdate(::arrow::field("new_col", arrow::boolean())) + .ValueOrDie() + ->Finish() + .ValueOrDie(); + auto batch = updater->Next().ValueOrDie(); + CHECK(batch); + auto result = updater->Next(); + CHECK(!result.ok()); +} + +TEST_CASE("Test update with projection") { + auto dataset = TestDataset(); + auto builder = dataset->NewUpdate(::arrow::field("new_col", arrow::utf8())).ValueOrDie(); + builder->Project({"ints"}); + auto updater = builder->Finish().ValueOrDie(); + while (true) { + auto batch = updater->Next().ValueOrDie(); + if (!batch) { + break; + } + CHECK(batch->schema()->Equals(*::arrow::schema({::arrow::field("ints", arrow::int32())}))); + auto input_arr = batch->GetColumnByName("ints"); + auto datum = ::arrow::compute::Cast(input_arr, ::arrow::utf8()).ValueOrDie(); + auto output_arr = datum.make_array(); + CHECK(updater->UpdateBatch(output_arr).ok()); + } } \ No newline at end of file diff --git a/cpp/src/lance/arrow/utils.cc b/cpp/src/lance/arrow/utils.cc index 085f4c4043..3456accd3c 100644 --- a/cpp/src/lance/arrow/utils.cc +++ b/cpp/src/lance/arrow/utils.cc @@ -14,6 +14,7 @@ #include "lance/arrow/utils.h" +#include #include #include #include @@ -272,6 +273,68 @@ ::arrow::Result> OpenDatase return std::dynamic_pointer_cast<::arrow::dataset::FileSystemDataset>(dataset); } +template +::arrow::Result> CreateArrayImpl( + const std::shared_ptr<::arrow::Scalar>& scalar, int64_t length, ::arrow::MemoryPool* pool) { + auto concrete_scalar = + std::dynamic_pointer_cast::ScalarType>(scalar); + auto builder = + std::make_shared::BuilderType>(scalar->type, pool); + ARROW_RETURN_NOT_OK(builder->Reserve(length)); + for (int64_t i = 0; i < length; i++) { + ARROW_RETURN_NOT_OK(builder->Append(concrete_scalar->value)); + } + return builder->Finish(); +} + +template <> +::arrow::Result> CreateArrayImpl<::arrow::StringType>( + const std::shared_ptr<::arrow::Scalar>& scalar, int64_t length, ::arrow::MemoryPool* pool) { + auto concrete_scalar = std::dynamic_pointer_cast<::arrow::StringScalar>(scalar); + auto builder = ::arrow::StringBuilder(pool); + ARROW_RETURN_NOT_OK(builder.Reserve(length)); + for (int64_t i = 0; i < length; i++) { + ARROW_RETURN_NOT_OK(builder.Append(concrete_scalar->view())); + } + return builder.Finish(); +} + +::arrow::Result> CreateArray( + const std::shared_ptr<::arrow::Scalar>& scalar, int64_t length, ::arrow::MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(auto builder, GetArrayBuilder(scalar->type, pool)); + switch (scalar->type->id()) { + case ::arrow::Type::BOOL: + return CreateArrayImpl<::arrow::BooleanType>(scalar, length, pool); + case ::arrow::Type::UINT8: + return CreateArrayImpl<::arrow::UInt8Type>(scalar, length, pool); + case ::arrow::Type::INT8: + return CreateArrayImpl<::arrow::Int8Type>(scalar, length, pool); + case ::arrow::Type::UINT16: + return CreateArrayImpl<::arrow::UInt16Type>(scalar, length, pool); + case ::arrow::Type::INT16: + return CreateArrayImpl<::arrow::Int16Type>(scalar, length, pool); + case ::arrow::Type::UINT32: + return CreateArrayImpl<::arrow::UInt32Type>(scalar, length, pool); + case ::arrow::Type::INT32: + return CreateArrayImpl<::arrow::Int32Type>(scalar, length, pool); + case ::arrow::Type::UINT64: + return CreateArrayImpl<::arrow::UInt64Type>(scalar, length, pool); + case ::arrow::Type::INT64: + return CreateArrayImpl<::arrow::Int64Type>(scalar, length, pool); + case ::arrow::Type::HALF_FLOAT: + return CreateArrayImpl<::arrow::HalfFloatType>(scalar, length, pool); + case ::arrow::Type::FLOAT: + return CreateArrayImpl<::arrow::FloatType>(scalar, length, pool); + case ::arrow::Type::DOUBLE: + return CreateArrayImpl<::arrow::DoubleType>(scalar, length, pool); + case ::arrow::Type::STRING: + return CreateArrayImpl<::arrow::StringType>(scalar, length, pool); + default: + return ::arrow::Status::Invalid( + fmt::format("CreateArray: unsupported type: {}", scalar->type->ToString())); + } +} + namespace { class UuidGenerator { diff --git a/cpp/src/lance/arrow/utils.h b/cpp/src/lance/arrow/utils.h index 8375026092..5c615771ff 100644 --- a/cpp/src/lance/arrow/utils.h +++ b/cpp/src/lance/arrow/utils.h @@ -55,6 +55,17 @@ ::arrow::Result> MergeSchema(const ::arrow::Sch ::arrow::Result> OpenDataset( const std::string& uri, std::shared_ptr<::arrow::dataset::Partitioning> partitioning = nullptr); +/// Create an array from a scalar value. +/// +/// \param scalar the value of each element in the array +/// \param length array length. +/// \param pool memory pool +/// \return +::arrow::Result> CreateArray( + const std::shared_ptr<::arrow::Scalar>& scalar, + int64_t length, + ::arrow::MemoryPool* pool = ::arrow::default_memory_pool()); + /// Get UUID string. std::string GetUUIDString();