Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-14999: [C++] Optional field name equality checks for map and list type #14847

Merged
merged 4 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/memory.h"
Expand Down Expand Up @@ -559,6 +560,14 @@ class TypeEqualsVisitor {
explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
: right_(right), check_metadata_(check_metadata), result_(false) {}

bool MetadataEqual(const Field& left, const Field& right) {
if (left.HasMetadata() && right.HasMetadata()) {
return left.metadata()->Equals(*right.metadata());
} else {
return !left.HasMetadata() && !right.HasMetadata();
}
}

Status VisitChildren(const DataType& left) {
if (left.num_fields() != right_.num_fields()) {
result_ = false;
Expand Down Expand Up @@ -626,8 +635,21 @@ class TypeEqualsVisitor {
}

template <typename T>
enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit(
const T& left) {
enable_if_t<is_list_like_type<T>::value, Status> Visit(const T& left) {
std::shared_ptr<Field> left_field = left.field(0);
std::shared_ptr<Field> right_field = checked_cast<const T&>(right_).field(0);
bool equal_names = !check_metadata_ || (left_field->name() == right_field->name());
bool equal_metadata = !check_metadata_ || MetadataEqual(*left_field, *right_field);

result_ = equal_names && equal_metadata &&
(left_field->nullable() == right_field->nullable()) &&
left_field->type()->Equals(*right_field->type(), check_metadata_);

return Status::OK();
}

template <typename T>
enable_if_t<is_struct_type<T>::value, Status> Visit(const T& left) {
return VisitChildren(left);
}

Expand All @@ -637,6 +659,18 @@ class TypeEqualsVisitor {
result_ = false;
return Status::OK();
}
if (check_metadata_ && (left.item_field()->name() != right.item_field()->name() ||
left.key_field()->name() != right.key_field()->name() ||
left.value_field()->name() != right.value_field()->name())) {
result_ = false;
return Status::OK();
}
if (check_metadata_ && !(MetadataEqual(*left.item_field(), *right.item_field()) &&
MetadataEqual(*left.key_field(), *right.key_field()) &&
MetadataEqual(*left.value_field(), *right.value_field()))) {
result_ = false;
return Status::OK();
}
result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) &&
left.item_type()->Equals(*right.item_type(), check_metadata_);
return Status::OK();
Expand Down
2 changes: 0 additions & 2 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,6 @@ class CastMetaFunction : public MetaFunction {
if (!is_nested(args[0].type()->id())) {
return args[0];
} else if (args[0].is_array()) {
// TODO(ARROW-14999): if types are equal except for field names of list
// types, we can also use this code path.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> array,
::arrow::internal::GetArrayView(
args[0].array(), cast_options->to_type.owned_type));
Expand Down
53 changes: 43 additions & 10 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,11 +412,11 @@ bool DataType::Equals(const DataType& other, bool check_metadata) const {
return TypeEquals(*this, other, check_metadata);
}

bool DataType::Equals(const std::shared_ptr<DataType>& other) const {
bool DataType::Equals(const std::shared_ptr<DataType>& other, bool check_metadata) const {
if (!other) {
return false;
}
return Equals(*other.get());
return Equals(*other.get(), check_metadata);
}

size_t DataType::Hash() const {
Expand Down Expand Up @@ -2090,6 +2090,10 @@ std::string DataType::ComputeMetadataFingerprint() const {
// Whatever the data type, metadata can only be found on child fields
std::string s;
for (const auto& child : children_) {
// Add field name to metadata fingerprint so that the field names within
// list and map types are included as part of the metadata. They are
// excluded from the base fingerprint.
s += child->name() + "=";
wjones127 marked this conversation as resolved.
Show resolved Hide resolved
s += child->metadata_fingerprint() + ";";
}
return s;
Expand Down Expand Up @@ -2136,17 +2140,33 @@ std::string DictionaryType::ComputeFingerprint() const {
}

std::string ListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
ss << '{' << child_fingerprint << '}';
return ss.str();
}
return "";
}

std::string LargeListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
ss << '{' << child_fingerprint << '}';
return ss.str();
}
return "";
}
Expand All @@ -2155,20 +2175,33 @@ std::string MapType::ComputeFingerprint() const {
const auto& key_fingerprint = key_type()->fingerprint();
const auto& item_fingerprint = item_type()->fingerprint();
if (!key_fingerprint.empty() && !item_fingerprint.empty()) {
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (keys_sorted_) {
return TypeIdFingerprint(*this) + "s{" + key_fingerprint + item_fingerprint + "}";
ss << 's';
}
if (item_field()->nullable()) {
ss << 'n';
} else {
return TypeIdFingerprint(*this) + "{" + key_fingerprint + item_fingerprint + "}";
ss << 'N';
}
ss << '{' << key_fingerprint + item_fingerprint << '}';
return ss.str();
}
return "";
}

std::string FixedSizeListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
std::stringstream ss;
ss << TypeIdFingerprint(*this) << "[" << list_size_ << "]"
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
ss << "[" << list_size_ << "]"
<< "{" << child_fingerprint << "}";
return ss.str();
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class ARROW_EXPORT DataType : public std::enable_shared_from_this<DataType>,
bool Equals(const DataType& other, bool check_metadata = false) const;

/// \brief Return whether the types are equal
bool Equals(const std::shared_ptr<DataType>& other) const;
bool Equals(const std::shared_ptr<DataType>& other, bool check_metadata = false) const;

/// \brief Return the child field at index i.
const std::shared_ptr<Field>& field(int i) const { return children_[i]; }
Expand Down
60 changes: 54 additions & 6 deletions cpp/src/arrow/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,8 @@ TEST(TestLargeListType, Basics) {
}

TEST(TestMapType, Basics) {
auto md = key_value_metadata({"foo"}, {"foo value"});

std::shared_ptr<DataType> kt = std::make_shared<StringType>();
std::shared_ptr<DataType> it = std::make_shared<UInt8Type>();

Expand Down Expand Up @@ -1294,6 +1296,41 @@ TEST(TestMapType, Basics) {
"some_entries",
struct_({field("some_key", kt, false), field("some_value", mt)}), false)));
AssertTypeEqual(mt3, *mt5);
// ...unless we explicitly ask about them.
ASSERT_FALSE(mt3.Equals(mt5, /*check_metadata=*/true));

// nullability of value type matters in comparisons
MapType map_type_non_nullable(kt, field("value", it, /*nullable=*/false));
AssertTypeNotEqual(map_type, map_type_non_nullable);
}

TEST(TestMapType, Metadata) {
auto md1 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
auto md2 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"});
auto md3 = key_value_metadata({"foo"}, {"foo value"});

auto t1 = map(utf8(), field("value", int32(), md1));
auto t2 = map(utf8(), field("value", int32(), md2));
auto t3 = map(utf8(), field("value", int32(), md3));
auto t4 =
std::make_shared<MapType>(field("key", utf8(), md1), field("value", int32(), md2));
ASSERT_OK_AND_ASSIGN(auto t5,
MapType::Make(field("some_entries",
struct_({field("some_key", utf8(), false),
field("some_value", int32(), md2)}),
false, md2)));

AssertTypeEqual(*t1, *t2);
AssertTypeEqual(*t1, *t2, /*check_metadata=*/true);

AssertTypeEqual(*t1, *t3);
AssertTypeNotEqual(*t1, *t3, /*check_metadata=*/true);

AssertTypeEqual(*t1, *t4);
AssertTypeNotEqual(*t1, *t4, /*check_metadata=*/true);

AssertTypeEqual(*t1, *t5);
AssertTypeNotEqual(*t1, *t5, /*check_metadata=*/true);
}

TEST(TestFixedSizeListType, Basics) {
Expand Down Expand Up @@ -1478,15 +1515,26 @@ TEST(TestListType, Equals) {
auto t1 = list(utf8());
auto t2 = list(utf8());
auto t3 = list(binary());
auto t4 = large_list(binary());
auto t5 = large_list(binary());
auto t6 = large_list(float64());
auto t4 = list(field("item", utf8(), /*nullable=*/false));
auto tl1 = large_list(binary());
auto tl2 = large_list(binary());
auto tl3 = large_list(float64());

AssertTypeEqual(*t1, *t2);
AssertTypeNotEqual(*t1, *t3);
AssertTypeNotEqual(*t3, *t4);
AssertTypeEqual(*t4, *t5);
AssertTypeNotEqual(*t5, *t6);
AssertTypeNotEqual(*t1, *t4);
AssertTypeNotEqual(*t3, *tl1);
AssertTypeEqual(*tl1, *tl2);
AssertTypeNotEqual(*tl2, *tl3);

std::shared_ptr<DataType> vt = std::make_shared<UInt8Type>();
std::shared_ptr<Field> inner_field = std::make_shared<Field>("non_default_name", vt);

ListType list_type(vt);
ListType list_type_named(inner_field);

AssertTypeEqual(list_type, list_type_named);
ASSERT_FALSE(list_type.Equals(list_type_named, /*check_metadata=*/true));
}

TEST(TestListType, Metadata) {
Expand Down
14 changes: 9 additions & 5 deletions java/c/src/test/python/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,18 @@ def round_trip_field(self, field_generator):
expected = field_generator()
self.assertEqual(expected, new_field)

def round_trip_array(self, array_generator, expected_diff=None):
def round_trip_array(self, array_generator, check_metadata=True):
original_arr = array_generator()
with self.bridge.java_c.CDataDictionaryProvider() as dictionary_provider, \
self.bridge.python_to_java_array(original_arr, dictionary_provider) as vector:
del original_arr
new_array = self.bridge.java_to_python_array(vector, dictionary_provider)

expected = array_generator()
if expected_diff:
self.assertEqual(expected, new_array.view(expected.type))
self.assertEqual(expected.diff(new_array), expected_diff or '')

self.assertEqual(expected, new_array)
if check_metadata:
self.assertTrue(new_array.type.equals(expected.type, check_metadata=True))

def round_trip_record_batch(self, rb_generator):
original_rb = rb_generator()
Expand Down Expand Up @@ -191,7 +192,10 @@ def test_int_array(self):
def test_list_array(self):
self.round_trip_array(lambda: pa.array(
[[], [0], [1, 2], [4, 5, 6]], pa.list_(pa.int64())
), "# Array types differed: list<item: int64> vs list<$data$: int64>\n")
# disabled check_metadata since the list internal field name ("item")
# is not preserved during round trips (it becomes "$data$").
), check_metadata=False)
wjones127 marked this conversation as resolved.
Show resolved Hide resolved


def test_struct_array(self):
fields = [
Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
cdef cppclass CDataType" arrow::DataType":
Type id()

c_bool Equals(const CDataType& other)
c_bool Equals(const shared_ptr[CDataType]& other)
c_bool Equals(const CDataType& other, c_bool check_metadata)
c_bool Equals(const shared_ptr[CDataType]& other, c_bool check_metadata)

shared_ptr[CField] field(int i)
const vector[shared_ptr[CField]] fields()
Expand Down
32 changes: 32 additions & 0 deletions python/pyarrow/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,21 @@ def test_list_type():
assert ty.value_type == pa.int64()
assert ty.value_field == pa.field("item", pa.int64(), nullable=True)

# nullability matters in comparison
ty_non_nullable = pa.list_(pa.field("item", pa.int64(), nullable=False))
assert ty != ty_non_nullable

# field names don't matter by default
ty_named = pa.list_(pa.field("element", pa.int64()))
assert ty == ty_named
assert not ty.equals(ty_named, check_metadata=True)

# metadata doesn't matter by default
ty_metadata = pa.list_(
pa.field("item", pa.int64(), metadata={"hello": "world"}))
assert ty == ty_metadata
assert not ty.equals(ty_metadata, check_metadata=True)

with pytest.raises(TypeError):
pa.list_(None)

Expand All @@ -540,6 +555,23 @@ def test_map_type():
assert ty.item_type == pa.int32()
assert ty.item_field == pa.field("value", pa.int32(), nullable=True)

# nullability matters in comparison
ty_non_nullable = pa.map_(pa.utf8(), pa.field(
"value", pa.int32(), nullable=False))
assert ty != ty_non_nullable

# field names don't matter by default
ty_named = pa.map_(pa.field("x", pa.utf8(), nullable=False),
pa.field("y", pa.int32()))
assert ty == ty_named
assert not ty.equals(ty_named, check_metadata=True)

# metadata doesn't matter by default
ty_metadata = pa.map_(pa.utf8(), pa.field(
"value", pa.int32(), metadata={"hello": "world"}))
assert ty == ty_metadata
assert not ty.equals(ty_metadata, check_metadata=True)

with pytest.raises(TypeError):
pa.map_(None)
with pytest.raises(TypeError):
Expand Down
Loading