Skip to content

Commit

Permalink
apacheGH-37669: [C++][Python] Fix casting to extension type with fixed size list storage type
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Jun 20, 2024
1 parent 89d6354 commit a9459e5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_cast_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@ Status CastToExtension(KernelContext* ctx, const ExecSpan& batch, ExecResult* ou
// Builds the cast function targeting Type::EXTENSION.
//
// Creates a CastFunction called `name` and registers the CastToExtension
// kernel once per input type id returned by AllTypeIds(), so every input
// type has a kernel whose output type is kOutputTargetType.
//
// NOTE(review): this listing comes from a diff view whose +/- markers were
// lost. The change summary reports 3 additions and 2 deletions, so the
// two-line AddKernel call below is presumably the removed (old) version and
// the three-line call after it the added (new) version -- confirm against the
// repository before treating both calls as live code.
std::shared_ptr<CastFunction> GetCastToExtension(std::string name) {
auto func = std::make_shared<CastFunction>(std::move(name), Type::EXTENSION);
for (Type::type in_ty : AllTypeIds()) {
// Presumably the pre-change registration: no explicit null-handling or
// memory-allocation arguments are passed here.
DCHECK_OK(
func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType, CastToExtension));
// Presumably the post-change registration: explicitly passes
// NullHandling::COMPUTED_NO_PREALLOCATE and MemAllocation::NO_PREALLOCATE.
// NOTE(review): looks like this makes the kernel responsible for producing
// its own validity bitmap and data buffers instead of having them
// preallocated -- confirm against the Arrow compute kernel documentation.
DCHECK_OK(func->AddKernel(in_ty, {InputType(in_ty)}, kOutputTargetType,
CastToExtension, NullHandling::COMPUTED_NO_PREALLOCATE,
MemAllocation::NO_PREALLOCATE));
}
return func;
}
Expand Down
60 changes: 60 additions & 0 deletions python/pyarrow/tests/test_extension_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,21 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type)


class MyFixedListType(pa.ExtensionType):
    """Test-only extension type whose storage must be a fixed-size list.

    The type carries no parameters beyond its storage type, so it
    serializes to an empty byte string.
    """

    def __init__(self, storage_type):
        # Only fixed-size list storage is accepted by this test type.
        assert isinstance(storage_type, pa.FixedSizeListType)
        extension_name = 'pyarrow.tests.MyFixedListType'
        super().__init__(storage_type, extension_name)

    def __arrow_ext_serialize__(self):
        # Nothing to persist besides the storage type itself.
        return b''

    @classmethod
    def __arrow_ext_deserialize__(cls, storage_type, serialized):
        # The serialized payload is always empty for this type.
        assert serialized == b''
        return cls(storage_type)


class AnnotatedType(pa.ExtensionType):
"""
Generic extension type that can store any storage type.
Expand Down Expand Up @@ -738,6 +753,36 @@ def test_casting_dict_array_to_extension_type():
UUID('30313233-3435-3637-3839-616263646566')]


def test_cast_to_extension_with_nested_storage():
    """Cast list arrays to extension types backed by list storage.

    For both fixed-size and variable-size lists, check a cast where the
    storage type already matches the source array and a cast where the
    list values must additionally be converted from float64 to float32.
    """
    # https://github.com/apache/arrow/issues/37669

    cases = [
        # (extension class, source array, factory for the list storage type)
        # With fixed-size list
        (
            MyFixedListType,
            pa.array([[1, 2], [3, 4], [5, 6]], pa.list_(pa.float64(), 2)),
            lambda value_type: pa.list_(value_type, 2),
        ),
        # With variable-size list
        (
            MyListType,
            pa.array([[1, 2], [3], [4, 5, 6]], pa.list_(pa.float64())),
            lambda value_type: pa.list_(value_type),
        ),
    ]

    for ext_cls, source, make_list_type in cases:
        # Storage type identical to the source array type.
        casted = source.cast(ext_cls(make_list_type(pa.float64())))
        wrapped = pa.ExtensionArray.from_storage(ext_cls(source.type), source)
        assert casted.equals(wrapped)

        # Storage type whose values differ (float64 -> float32).
        target_type = ext_cls(make_list_type(pa.float32()))
        casted = source.cast(target_type)
        wrapped = pa.ExtensionArray.from_storage(
            target_type, source.cast(target_type.storage_type)
        )
        assert casted.equals(wrapped)


def test_concat():
arr1 = pa.array([1, 2, 3], IntegerType())
arr2 = pa.array([4, 5, 6], IntegerType())
Expand Down Expand Up @@ -1500,6 +1545,21 @@ def test_tensor_type_equality():
assert not tensor_type == tensor_type3


def test_tensor_type_cast():
    """Round-trip between a fixed-size list array and a fixed shape tensor."""
    tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 3])
    flat_values = pa.array(range(18), pa.int8())
    storage = pa.FixedSizeListArray.from_arrays(flat_values, 6)

    # cast storage -> extension type
    as_tensor = storage.cast(tensor_type)
    assert as_tensor.equals(
        pa.ExtensionArray.from_storage(tensor_type, storage)
    )

    # cast extension type -> storage type
    back_to_storage = as_tensor.cast(storage.type)
    assert back_to_storage.equals(storage)


@pytest.mark.pandas
def test_extension_to_pandas_storage_type(registered_period_type):
period_type, _ = registered_period_type
Expand Down

0 comments on commit a9459e5

Please sign in to comment.