Skip to content

Commit

Permalink
GH-17682: [C++][Python] Bool8 Extension Type Implementation (#43488)
Browse files Browse the repository at this point in the history
### Rationale for this change

C++ and Python implementations of #43234

### What changes are included in this PR?

- Implement C++ `Bool8Type`, `Bool8Array`, `Bool8Scalar`, and tests
- Implement Python bindings to C++, as well as zero-copy numpy conversion methods
- TODO: docs waiting for rebase on #43458

### Are these changes tested?

Yes

### Are there any user-facing changes?

Bool8 extension type will be available in C++ and Python libraries

* GitHub Issue: #17682

Authored-by: Joel Lubinitsky <joellubi@gmail.com>
Signed-off-by: Felipe Oliveira Carvalho <felipekde@gmail.com>
  • Loading branch information
joellubi authored Aug 21, 2024
1 parent cc3c868 commit 5258819
Show file tree
Hide file tree
Showing 15 changed files with 604 additions and 7 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,7 @@ endif()

if(ARROW_JSON)
arrow_add_object_library(ARROW_JSON
extension/bool8.cc

This comment has been minimized.

Copy link
@rok

rok Aug 21, 2024

Member

@joellubi Looking at bool8.cc it doesn't seem to use JSON, so bool8 could be available even when compiled with ARROW_JSON=false. I've tested this on the UUID PR and CI seems to be ok with it. Please let me know if I'm missing something.

This comment has been minimized.

Copy link
@joellubi

joellubi Aug 21, 2024

Author Member

Hi @rok. You're correct, bool8 shouldn't require JSON. This was an oversight on my part. Please feel free to include the change in your PR!

extension/fixed_shape_tensor.cc
extension/opaque.cc
json/options.cc
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
# specific language governing permissions and limitations
# under the License.

add_arrow_test(test
SOURCES
bool8_test.cc
PREFIX
"arrow-extension-bool8")

add_arrow_test(test
SOURCES
fixed_shape_tensor_test.cc
Expand Down
61 changes: 61 additions & 0 deletions cpp/src/arrow/extension/bool8.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <sstream>

#include "arrow/extension/bool8.h"
#include "arrow/util/logging.h"

namespace arrow::extension {

bool Bool8Type::ExtensionEquals(const ExtensionType& other) const {
return extension_name() == other.extension_name();
}

std::string Bool8Type::ToString(bool show_metadata) const {
std::stringstream ss;
ss << "extension<" << this->extension_name() << ">";
return ss.str();
}

std::string Bool8Type::Serialize() const { return ""; }

Result<std::shared_ptr<DataType>> Bool8Type::Deserialize(
std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
if (storage_type->id() != Type::INT8) {
return Status::Invalid("Expected INT8 storage type, got ", storage_type->ToString());
}
if (serialized_data != "") {
return Status::Invalid("Serialize data must be empty, got ", serialized_data);
}
return bool8();
}

std::shared_ptr<Array> Bool8Type::MakeArray(std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("arrow.bool8",
internal::checked_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<Bool8Array>(data);
}

Result<std::shared_ptr<DataType>> Bool8Type::Make() {
return std::make_shared<Bool8Type>();
}

std::shared_ptr<DataType> bool8() { return std::make_shared<Bool8Type>(); }

} // namespace arrow::extension
58 changes: 58 additions & 0 deletions cpp/src/arrow/extension/bool8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension_type.h"

namespace arrow::extension {

/// \brief Bool8 is an alternate representation for boolean
/// arrays using 8 bits instead of 1 bit per value. The underlying
/// storage type is int8.
class ARROW_EXPORT Bool8Array : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
};

/// \brief Bool8 is an alternate representation for boolean
/// arrays using 8 bits instead of 1 bit per value. The underlying
/// storage type is int8.
class ARROW_EXPORT Bool8Type : public ExtensionType {
public:
/// \brief Construct a Bool8Type.
Bool8Type() : ExtensionType(int8()) {}

std::string extension_name() const override { return "arrow.bool8"; }
std::string ToString(bool show_metadata = false) const override;

bool ExtensionEquals(const ExtensionType& other) const override;

std::string Serialize() const override;

Result<std::shared_ptr<DataType>> Deserialize(
std::shared_ptr<DataType> storage_type,
const std::string& serialized_data) const override;

/// Create a Bool8Array from ArrayData
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

static Result<std::shared_ptr<DataType>> Make();
};

/// \brief Return a Bool8Type instance.
ARROW_EXPORT std::shared_ptr<DataType> bool8();

} // namespace arrow::extension
91 changes: 91 additions & 0 deletions cpp/src/arrow/extension/bool8_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension/bool8.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/ipc/writer.h"
#include "arrow/testing/extension_type.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {

TEST(Bool8Type, Basics) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
auto type2 = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
ASSERT_EQ("arrow.bool8", type->extension_name());
ASSERT_EQ(*type, *type);
ASSERT_NE(*arrow::null(), *type);
ASSERT_EQ(*type, *type2);
ASSERT_EQ(*arrow::int8(), *type->storage_type());
ASSERT_EQ("", type->Serialize());
ASSERT_EQ("extension<arrow.bool8>", type->ToString(false));
}

TEST(Bool8Type, CreateFromArray) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]");
auto array = ExtensionType::WrapArray(type, storage);
ASSERT_EQ(5, array->length());
ASSERT_EQ(1, array->null_count());
}

TEST(Bool8Type, Deserialize) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
ASSERT_OK_AND_ASSIGN(auto deserialized, type->Deserialize(type->storage_type(), ""));
ASSERT_EQ(*type, *deserialized);
ASSERT_NOT_OK(type->Deserialize(type->storage_type(), "must be empty"));
ASSERT_EQ(*type, *deserialized);
ASSERT_NOT_OK(type->Deserialize(uint8(), ""));
ASSERT_EQ(*type, *deserialized);
}

TEST(Bool8Type, MetadataRoundTrip) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());
std::string serialized = type->Serialize();
ASSERT_OK_AND_ASSIGN(auto deserialized,
type->Deserialize(type->storage_type(), serialized));
ASSERT_EQ(*type, *deserialized);
}

TEST(Bool8Type, BatchRoundTrip) {
auto type = internal::checked_pointer_cast<extension::Bool8Type>(extension::bool8());

auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]");
auto array = ExtensionType::WrapArray(type, storage);
auto batch =
RecordBatch::Make(schema({field("field", type)}), array->length(), {array});

std::shared_ptr<RecordBatch> written;
{
ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create());
ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(),
out_stream.get()));

ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish());

io::BufferReader reader(complete_ipc_stream);
std::shared_ptr<RecordBatchReader> batch_reader;
ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader));
ASSERT_OK(batch_reader->ReadNext(&written));
}

ASSERT_EQ(*batch->schema(), *written->schema());
ASSERT_BATCHES_EQUAL(*batch, *written);
}

} // namespace arrow
7 changes: 5 additions & 2 deletions cpp/src/arrow/extension_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/chunked_array.h"
#include "arrow/config.h"
#ifdef ARROW_JSON
#include "arrow/extension/bool8.h"
#include "arrow/extension/fixed_shape_tensor.h"
#endif
#include "arrow/status.h"
Expand Down Expand Up @@ -146,10 +147,12 @@ static void CreateGlobalRegistry() {

#ifdef ARROW_JSON
// Register canonical extension types
auto ext_type =
auto fst_ext_type =
checked_pointer_cast<ExtensionType>(extension::fixed_shape_tensor(int64(), {}));
ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type));

ARROW_CHECK_OK(g_registry->RegisterType(ext_type));
auto bool8_ext_type = checked_pointer_cast<ExtensionType>(extension::bool8());
ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type));
#endif
}

Expand Down
7 changes: 4 additions & 3 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def print_entry(label, value):
run_end_encoded,
fixed_shape_tensor,
opaque,
bool8,
field,
type_for_alias,
DataType, DictionaryType, StructType,
Expand All @@ -184,7 +185,7 @@ def print_entry(label, value):
FixedSizeBinaryType, Decimal128Type, Decimal256Type,
BaseExtensionType, ExtensionType,
RunEndEncodedType, FixedShapeTensorType, OpaqueType,
PyExtensionType, UnknownExtensionType,
Bool8Type, PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
KeyValueMetadata,
Expand Down Expand Up @@ -218,7 +219,7 @@ def print_entry(label, value):
MonthDayNanoIntervalArray,
Decimal128Array, Decimal256Array, StructArray, ExtensionArray,
RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray,
scalar, NA, _NULL as NULL, Scalar,
Bool8Array, scalar, NA, _NULL as NULL, Scalar,
NullScalar, BooleanScalar,
Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar,
UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar,
Expand All @@ -235,7 +236,7 @@ def print_entry(label, value):
FixedSizeBinaryScalar, DictionaryScalar,
MapScalar, StructScalar, UnionScalar,
RunEndEncodedScalar, ExtensionScalar,
FixedShapeTensorScalar, OpaqueScalar)
FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar)

# Buffers, allocation
from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager,
Expand Down
Loading

0 comments on commit 5258819

Please sign in to comment.