From 54989df19406420fe469dfec7f374c3fac8f4bb8 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 5 Dec 2022 15:22:16 -0800 Subject: [PATCH 1/4] feat(api) check_metadata should check internal field names --- cpp/src/arrow/compare.cc | 38 ++++++++++++- cpp/src/arrow/compute/cast.cc | 2 - cpp/src/arrow/type.cc | 50 +++++++++++++---- cpp/src/arrow/type.h | 2 +- cpp/src/arrow/type_test.cc | 60 ++++++++++++++++++--- java/c/src/test/python/integration_tests.py | 11 ++-- python/pyarrow/includes/libarrow.pxd | 4 +- python/pyarrow/tests/test_types.py | 32 +++++++++++ python/pyarrow/types.pxi | 13 +++-- r/R/arrowExports.R | 4 +- r/R/type.R | 4 +- r/src/arrowExports.cpp | 9 ++-- r/src/datatype.cpp | 5 +- r/tests/testthat/test-data-type.R | 16 ++++++ r/tests/testthat/test-parquet.R | 5 +- 15 files changed, 210 insertions(+), 45 deletions(-) diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index baadd10cca98b..fa83426ab7f04 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -43,6 +43,7 @@ #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_reader.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/memory.h" @@ -559,6 +560,14 @@ class TypeEqualsVisitor { explicit TypeEqualsVisitor(const DataType& right, bool check_metadata) : right_(right), check_metadata_(check_metadata), result_(false) {} + bool MetadataEqual(const Field& left, const Field& right) { + if (left.HasMetadata() && right.HasMetadata()) { + return left.metadata()->Equals(*right.metadata()); + } else { + return !left.HasMetadata() && !right.HasMetadata(); + } + } + Status VisitChildren(const DataType& left) { if (left.num_fields() != right_.num_fields()) { result_ = false; @@ -626,8 +635,21 @@ class TypeEqualsVisitor { } template - enable_if_t::value || is_struct_type::value, Status> Visit( - const T& left) { + enable_if_t::value, Status> Visit(const T& left) { + std::shared_ptr left_field = left.field(0); + std::shared_ptr right_field = checked_cast(right_).field(0); + bool equal_names = !check_metadata_ || (left_field->name() == right_field->name()); + bool equal_metadata = !check_metadata_ || MetadataEqual(*left_field, *right_field); + + result_ = equal_names && equal_metadata && + (left_field->nullable() == right_field->nullable()) && + left_field->type()->Equals(*right_field->type(), check_metadata_); + + return Status::OK(); + } + + template + enable_if_t::value, Status> Visit(const T& left) { return VisitChildren(left); } @@ -637,6 +659,18 @@ class TypeEqualsVisitor { result_ = false; return Status::OK(); } + if (check_metadata_ && (left.item_field()->name() != right.item_field()->name() || + left.key_field()->name() != right.key_field()->name() || + left.value_field()->name() != right.value_field()->name())) { + result_ = false; + return Status::OK(); + } + if (check_metadata_ && !(MetadataEqual(*left.item_field(), *right.item_field()) && + MetadataEqual(*left.key_field(), *right.key_field()) && + MetadataEqual(*left.value_field(), *right.value_field()))) { + result_ = false; + return Status::OK(); + } result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) && left.item_type()->Equals(*right.item_type(), check_metadata_); return Status::OK(); diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 99e8b89f1ca13..13bf6f85a4874 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -103,8 +103,6 @@ class CastMetaFunction : public MetaFunction { if (!is_nested(args[0].type()->id())) { return args[0]; } else if (args[0].is_array()) { - // TODO(ARROW-14999): if types are equal except for field names of list - // types, we can also use this code path. ARROW_ASSIGN_OR_RAISE(std::shared_ptr array, ::arrow::internal::GetArrayView( args[0].array(), cast_options->to_type.owned_type)); diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ea9525404c816..1e0001c208f99 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -412,11 +412,11 @@ bool DataType::Equals(const DataType& other, bool check_metadata) const { return TypeEquals(*this, other, check_metadata); } -bool DataType::Equals(const std::shared_ptr& other) const { +bool DataType::Equals(const std::shared_ptr& other, bool check_metadata) const { if (!other) { return false; } - return Equals(*other.get()); + return Equals(*other.get(), check_metadata); } size_t DataType::Hash() const { @@ -2090,6 +2090,7 @@ std::string DataType::ComputeMetadataFingerprint() const { // Whatever the data type, metadata can only be found on child fields std::string s; for (const auto& child : children_) { + s += child->name() + "="; s += child->metadata_fingerprint() + ";"; } return s; @@ -2136,17 +2137,33 @@ std::string DictionaryType::ComputeFingerprint() const { } std::string ListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { - return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}"; + std::stringstream ss; + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << '{' << child_fingerprint << '}'; + return ss.str(); } return ""; } std::string LargeListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { - return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}"; + std::stringstream ss; + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << '{' << child_fingerprint << '}'; + return ss.str(); } return ""; } @@ -2155,20 +2172,33 @@ std::string MapType::ComputeFingerprint() const { const auto& key_fingerprint = key_type()->fingerprint(); const auto& item_fingerprint = item_type()->fingerprint(); if (!key_fingerprint.empty() && !item_fingerprint.empty()) { + std::stringstream ss; + ss << TypeIdFingerprint(*this); if (keys_sorted_) { - return TypeIdFingerprint(*this) + "s{" + key_fingerprint + item_fingerprint + "}"; + ss << 's'; + } + if (item_field()->nullable()) { + ss << 'n'; } else { - return TypeIdFingerprint(*this) + "{" + key_fingerprint + item_fingerprint + "}"; + ss << 'N'; } + ss << '{' << key_fingerprint + item_fingerprint << '}'; + return ss.str(); } return ""; } std::string FixedSizeListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { std::stringstream ss; - ss << TypeIdFingerprint(*this) << "[" << list_size_ << "]" + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << "[" << list_size_ << "]" << "{" << child_fingerprint << "}"; return ss.str(); } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 415aaacf1c9ef..4bf8fe7fabb9b 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -140,7 +140,7 @@ class ARROW_EXPORT DataType : public std::enable_shared_from_this, bool Equals(const DataType& other, bool check_metadata = false) const; /// \brief Return whether the types are equal - bool Equals(const std::shared_ptr& other) const; + bool Equals(const std::shared_ptr& other, bool check_metadata = false) const; /// \brief Return the child field at index i. const std::shared_ptr& field(int i) const { return children_[i]; } diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 954ad63c8aa68..36206e68f8b8e 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1262,6 +1262,8 @@ TEST(TestLargeListType, Basics) { } TEST(TestMapType, Basics) { + auto md = key_value_metadata({"foo"}, {"foo value"}); + std::shared_ptr kt = std::make_shared(); std::shared_ptr it = std::make_shared(); @@ -1294,6 +1296,41 @@ TEST(TestMapType, Basics) { "some_entries", struct_({field("some_key", kt, false), field("some_value", mt)}), false))); AssertTypeEqual(mt3, *mt5); + // ...unless we explicitly ask about them. + ASSERT_FALSE(mt3.Equals(mt5, /*check_metadata=*/true)); + + // nullability of value type matters in comparisons + MapType map_type_non_nullable(kt, field("value", it, /*nullable=*/false)); + AssertTypeNotEqual(map_type, map_type_non_nullable); +} + +TEST(TestMapType, Metadata) { + auto md1 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"}); + auto md2 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"}); + auto md3 = key_value_metadata({"foo"}, {"foo value"}); + + auto t1 = map(utf8(), field("value", int32(), md1)); + auto t2 = map(utf8(), field("value", int32(), md2)); + auto t3 = map(utf8(), field("value", int32(), md3)); + auto t4 = + std::make_shared(field("key", utf8(), md1), field("value", int32(), md2)); + ASSERT_OK_AND_ASSIGN(auto t5, + MapType::Make(field("some_entries", + struct_({field("some_key", utf8(), false), + field("some_value", int32(), md2)}), + false, md2))); + + AssertTypeEqual(*t1, *t2); + AssertTypeEqual(*t1, *t2, /*check_metadata=*/true); + + AssertTypeEqual(*t1, *t3); + AssertTypeNotEqual(*t1, *t3, /*check_metadata=*/true); + + AssertTypeEqual(*t1, *t4); + AssertTypeNotEqual(*t1, *t4, /*check_metadata=*/true); + + AssertTypeEqual(*t1, *t5); + AssertTypeNotEqual(*t1, *t5, /*check_metadata=*/true); } TEST(TestFixedSizeListType, Basics) { @@ -1478,15 +1515,26 @@ TEST(TestListType, Equals) { auto t1 = list(utf8()); auto t2 = list(utf8()); auto t3 = list(binary()); - auto t4 = large_list(binary()); - auto t5 = large_list(binary()); - auto t6 = large_list(float64()); + auto t4 = list(field("item", utf8(), /*nullable=*/false)); + auto tl1 = large_list(binary()); + auto tl2 = large_list(binary()); + auto tl3 = large_list(float64()); AssertTypeEqual(*t1, *t2); AssertTypeNotEqual(*t1, *t3); - AssertTypeNotEqual(*t3, *t4); - AssertTypeEqual(*t4, *t5); - AssertTypeNotEqual(*t5, *t6); + AssertTypeNotEqual(*t1, *t4); + AssertTypeNotEqual(*t3, *tl1); + AssertTypeEqual(*tl1, *tl2); + AssertTypeNotEqual(*tl2, *tl3); + + std::shared_ptr vt = std::make_shared(); + std::shared_ptr inner_field = std::make_shared("non_default_name", vt); + + ListType list_type(vt); + ListType list_type_named(inner_field); + + AssertTypeEqual(list_type, list_type_named); + ASSERT_FALSE(list_type.Equals(list_type_named, /*check_metadata=*/true)); } TEST(TestListType, Metadata) { diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index 33ff1cf4a9af5..636269209fbe8 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -142,7 +142,7 @@ def round_trip_field(self, field_generator): expected = field_generator() self.assertEqual(expected, new_field) - def round_trip_array(self, array_generator, expected_diff=None): + def round_trip_array(self, array_generator, check_metadata=True): original_arr = array_generator() with self.bridge.java_c.CDataDictionaryProvider() as dictionary_provider, \ self.bridge.python_to_java_array(original_arr, dictionary_provider) as vector: @@ -150,9 +150,10 @@ def round_trip_array(self, array_generator, expected_diff=None): new_array = self.bridge.java_to_python_array(vector, dictionary_provider) expected = array_generator() - if expected_diff: - self.assertEqual(expected, new_array.view(expected.type)) - self.assertEqual(expected.diff(new_array), expected_diff or '') + + self.assertEqual(expected, new_array) + if check_metadata: + self.assertTrue(new_array.equals(expected, check_metadata=True)) def round_trip_record_batch(self, rb_generator): original_rb = rb_generator() @@ -191,7 +192,7 @@ def test_int_array(self): def test_list_array(self): self.round_trip_array(lambda: pa.array( [[], [0], [1, 2], [4, 5, 6]], pa.list_(pa.int64()) - ), "# Array types differed: list vs list<$data$: int64>\n") + ), check_metadata=False) def test_struct_array(self): fields = [ diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e2346c6346129..279058420f460 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -153,8 +153,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDataType" arrow::DataType": Type id() - c_bool Equals(const CDataType& other) - c_bool Equals(const shared_ptr[CDataType]& other) + c_bool Equals(const CDataType& other, c_bool check_metadata) + c_bool Equals(const shared_ptr[CDataType]& other, c_bool check_metadata) shared_ptr[CField] field(int i) const vector[shared_ptr[CField]] fields() diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index e922ca0e1caf6..c780cd80c7928 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -518,6 +518,21 @@ def test_list_type(): assert ty.value_type == pa.int64() assert ty.value_field == pa.field("item", pa.int64(), nullable=True) + # nullability matters in comparison + ty_non_nullable = pa.list_(pa.field("item", pa.int64(), nullable=False)) + assert ty != ty_non_nullable + + # field names don't matter by default + ty_named = pa.list_(pa.field("element", pa.int64())) + assert ty == ty_named + assert not ty.equals(ty_named, check_metadata=True) + + # metadata doesn't matter by default + ty_metadata = pa.list_( + pa.field("item", pa.int64(), metadata={"hello": "world"})) + assert ty == ty_metadata + assert not ty.equals(ty_metadata, check_metadata=True) + with pytest.raises(TypeError): pa.list_(None) @@ -540,6 +555,23 @@ def test_map_type(): assert ty.item_type == pa.int32() assert ty.item_field == pa.field("value", pa.int32(), nullable=True) + # nullability matters in comparison + ty_non_nullable = pa.map_(pa.utf8(), pa.field( + "value", pa.int32(), nullable=False)) + assert ty != ty_non_nullable + + # field names don't matter by default + ty_named = pa.map_(pa.field("x", pa.utf8(), nullable=False), + pa.field("y", pa.int32())) + assert ty == ty_named + assert not ty.equals(ty_named, check_metadata=True) + + # metadata doesn't matter by default + ty_metadata = pa.map_(pa.utf8(), pa.field( + "value", pa.int32(), metadata={"hello": "world"})) + assert ty == ty_metadata + assert not ty.equals(ty_metadata, check_metadata=True) + with pytest.raises(TypeError): pa.map_(None) with pytest.raises(TypeError): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 8d5b261acb967..40440a2f3de7d 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -192,22 +192,27 @@ cdef class DataType(_Weakrefable): except (TypeError, ValueError): return NotImplemented - def equals(self, other): + def equals(self, other, check_metadata=False): """ Return true if type is equivalent to passed value. Parameters ---------- other : DataType or string convertible to DataType + check_metadata : bool + Whether nested Field metadata equality should be checked as well. Returns ------- is_equal : bool """ - cdef DataType other_type + cdef: + DataType other_type + c_bool c_check_metadata other_type = ensure_type(other) - return self.type.Equals(deref(other_type.type)) + c_check_metadata = check_metadata + return self.type.Equals(deref(other_type.type), c_check_metadata) def to_pandas_dtype(self): """ @@ -870,7 +875,7 @@ cdef class BaseExtensionType(DataType): f"Expected array or chunked array, got {storage.__class__}") if not c_storage_type.get().Equals(deref(self.ext_type) - .storage_type()): + .storage_type(), False): raise TypeError( f"Incompatible storage type for {self}: " f"expected {self.storage_type}, got {storage.type}") diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 144044d7e74e0..7f219fddc35ab 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -936,8 +936,8 @@ DataType__name <- function(type) { .Call(`_arrow_DataType__name`, type) } -DataType__Equals <- function(lhs, rhs) { - .Call(`_arrow_DataType__Equals`, lhs, rhs) +DataType__Equals <- function(lhs, rhs, check_metadata) { + .Call(`_arrow_DataType__Equals`, lhs, rhs, check_metadata) } DataType__num_fields <- function(type) { diff --git a/r/R/type.R b/r/R/type.R index cda606e3fa955..ddd39e5c11602 100644 --- a/r/R/type.R +++ b/r/R/type.R @@ -37,8 +37,8 @@ DataType <- R6Class("DataType", ToString = function() { DataType__ToString(self) }, - Equals = function(other, ...) { - inherits(other, "DataType") && DataType__Equals(self, other) + Equals = function(other, check_metadata = FALSE, ...) { + inherits(other, "DataType") && DataType__Equals(self, other, isTRUE(check_metadata)) }, fields = function() { DataType__fields(self) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index d3f97f5a99f74..7b3e4be90a3bd 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2426,12 +2426,13 @@ BEGIN_CPP11 END_CPP11 } // datatype.cpp -bool DataType__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_DataType__Equals(SEXP lhs_sexp, SEXP rhs_sexp){ +bool DataType__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs, bool check_metadata); +extern "C" SEXP _arrow_DataType__Equals(SEXP lhs_sexp, SEXP rhs_sexp, SEXP check_metadata_sexp){ BEGIN_CPP11 arrow::r::Input&>::type lhs(lhs_sexp); arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(DataType__Equals(lhs, rhs)); + arrow::r::Input::type check_metadata(check_metadata_sexp); + return cpp11::as_sexp(DataType__Equals(lhs, rhs, check_metadata)); END_CPP11 } // datatype.cpp @@ -5511,7 +5512,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_struct__", (DL_FUNC) &_arrow_struct__, 1}, { "_arrow_DataType__ToString", (DL_FUNC) &_arrow_DataType__ToString, 1}, { "_arrow_DataType__name", (DL_FUNC) &_arrow_DataType__name, 1}, - { "_arrow_DataType__Equals", (DL_FUNC) &_arrow_DataType__Equals, 2}, + { "_arrow_DataType__Equals", (DL_FUNC) &_arrow_DataType__Equals, 3}, { "_arrow_DataType__num_fields", (DL_FUNC) &_arrow_DataType__num_fields, 1}, { "_arrow_DataType__fields", (DL_FUNC) &_arrow_DataType__fields, 1}, { "_arrow_DataType__id", (DL_FUNC) &_arrow_DataType__id, 1}, diff --git a/r/src/datatype.cpp b/r/src/datatype.cpp index dc8d3b18926ae..959556c9f174d 100644 --- a/r/src/datatype.cpp +++ b/r/src/datatype.cpp @@ -327,8 +327,9 @@ std::string DataType__name(const std::shared_ptr& type) { // [[arrow::export]] bool DataType__Equals(const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return lhs->Equals(*rhs); + const std::shared_ptr& rhs, + bool check_metadata) { + return lhs->Equals(*rhs, check_metadata); } // [[arrow::export]] diff --git a/r/tests/testthat/test-data-type.R b/r/tests/testthat/test-data-type.R index 16fcf8e0a38cb..0f193f19d3733 100644 --- a/r/tests/testthat/test-data-type.R +++ b/r/tests/testthat/test-data-type.R @@ -365,6 +365,14 @@ test_that("list type works as expected", { ) expect_equal(x$value_type, int32()) expect_equal(x$value_field, field("item", int32())) + + # nullability matters in comparison + expect_false(x$Equals(list_of(field("item", int32(), nullable = FALSE)))) + + # field names don't matter by default + other_name <- list_of(field("other", int32())) + expect_equal(x, other_name, ignore_attr = TRUE) + expect_false(x$Equals(other_name, check_metadata = TRUE)) }) test_that("map type works as expected", { @@ -388,6 +396,14 @@ test_that("map type works as expected", { # we can make this comparison: # expect_equal(x$value_type, struct(key = x$key_field, value = x$item_field)) # nolint expect_false(x$keys_sorted) + + # nullability matters in comparison + expect_false(x$Equals(map_of(int32(), field("value", utf8(), nullable = FALSE)))) + + # field names don't matter by default + other_name <- map_of(int32(), field("other", utf8())) + expect_equal(x, other_name, ignore_attr = TRUE) + expect_false(x$Equals(other_name, check_metadata = TRUE)) }) test_that("map type validates arguments", { diff --git a/r/tests/testthat/test-parquet.R b/r/tests/testthat/test-parquet.R index 32170534a47c3..591805d4ff5ec 100644 --- a/r/tests/testthat/test-parquet.R +++ b/r/tests/testthat/test-parquet.R @@ -457,9 +457,8 @@ test_that("Can read parquet with nested lists and maps", { skip_if_not(dir.exists(parquet_test_data), "Parquet test data missing") pq <- read_parquet(paste0(parquet_test_data, "/nested_lists.snappy.parquet"), as_data_frame = FALSE) - # value name is "element" from parquet reader, but type default is "item" - expect_equal(pq$a$type, list_of(field("element", list_of(field("element", list_of(field("element", utf8()))))))) + expect_equal(pq$a$type, list_of(list_of(list_of(utf8())))) pq <- read_parquet(paste0(parquet_test_data, "/nested_maps.snappy.parquet"), as_data_frame = FALSE) - expect_equal(pq$a$type, map_of(utf8(), map_of(int32(), boolean()))) + expect_equal(pq$a$type, map_of(utf8(), map_of(int32(), field("val", boolean(), nullable = FALSE)))) }) From 08fe81ac12638c4ec7787d60f00ec143fc1d87b9 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 6 Dec 2022 08:01:11 -0800 Subject: [PATCH 2/4] fix: add type --- java/c/src/test/python/integration_tests.py | 2 +- r/src/datatype.cpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index 636269209fbe8..a84381e858f50 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -153,7 +153,7 @@ def round_trip_array(self, array_generator, check_metadata=True): self.assertEqual(expected, new_array) if check_metadata: - self.assertTrue(new_array.equals(expected, check_metadata=True)) + self.assertTrue(new_array.type.equals(expected.type, check_metadata=True)) def round_trip_record_batch(self, rb_generator): original_rb = rb_generator() diff --git a/r/src/datatype.cpp b/r/src/datatype.cpp index 959556c9f174d..f19ba92527157 100644 --- a/r/src/datatype.cpp +++ b/r/src/datatype.cpp @@ -327,8 +327,7 @@ std::string DataType__name(const std::shared_ptr& type) { // [[arrow::export]] bool DataType__Equals(const std::shared_ptr& lhs, - const std::shared_ptr& rhs, - bool check_metadata) { + const std::shared_ptr& rhs, bool check_metadata) { return lhs->Equals(*rhs, check_metadata); } From 58a36c858d0291a54a84024a4fabb4e154cd3b9e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 7 Dec 2022 12:09:21 -0800 Subject: [PATCH 3/4] Update python/pyarrow/types.pxi Co-authored-by: Antoine Pitrou --- python/pyarrow/types.pxi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 40440a2f3de7d..d771ac7351708 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -192,7 +192,7 @@ cdef class DataType(_Weakrefable): except (TypeError, ValueError): return NotImplemented - def equals(self, other, check_metadata=False): + def equals(self, other, *, check_metadata=False): """ Return true if type is equivalent to passed value. From 4da101305120b0d30aa3ec93e5568e168249af0e Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 7 Dec 2022 12:13:52 -0800 Subject: [PATCH 4/4] chore: add internal comments --- cpp/src/arrow/type.cc | 3 +++ java/c/src/test/python/integration_tests.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 1e0001c208f99..cc31735512bad 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -2090,6 +2090,9 @@ std::string DataType::ComputeMetadataFingerprint() const { // Whatever the data type, metadata can only be found on child fields std::string s; for (const auto& child : children_) { + // Add field name to metadata fingerprint so that the field names within + // list and map types are included as part of the metadata. They are + // excluded from the base fingerprint. s += child->name() + "="; s += child->metadata_fingerprint() + ";"; } diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index a84381e858f50..c23b4b9b4416e 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -192,7 +192,10 @@ def test_int_array(self): def test_list_array(self): self.round_trip_array(lambda: pa.array( [[], [0], [1, 2], [4, 5, 6]], pa.list_(pa.int64()) + # disabled check_metadata since the list internal field name ("item") + # is not preserved during round trips (it becomes "$data$"). ), check_metadata=False) + def test_struct_array(self): fields = [