Skip to content

Commit

Permalink
document kwargs + raise error when trying to cast non-CPU data
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Jun 26, 2024
1 parent b2ad739 commit 671efda
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
9 changes: 8 additions & 1 deletion python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,10 @@ cdef class Array(_PandasConvertible):
schema. PyArrow will attempt to cast the array to this data type.
If None, the array will be returned as-is, with a type matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.
Returns
-------
Expand All @@ -1928,7 +1932,10 @@ cdef class Array(_PandasConvertible):
target_type = DataType._import_from_c_capsule(requested_schema)

if target_type != self.type:
# TODO should protect from trying to cast non-CPU data
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_array = _pc().cast(self, target_type, safe=True)
inner_array = pyarrow_unwrap_array(casted_array)
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
int num_columns()
int64_t num_rows()

CDeviceAllocationType device_type()

CStatus Validate() const
CStatus ValidateFull() const

Expand Down
27 changes: 26 additions & 1 deletion python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -3782,6 +3782,10 @@ cdef class RecordBatch(_Tabular):
schema. PyArrow will attempt to cast the batch to this data type.
If None, the batch will be returned as-is, with a type matching the
one returned by :meth:`__arrow_c_schema__()`.
kwargs
Currently no additional keyword arguments are supported, but
this method will accept any keyword with a value of ``None``
for compatibility with future keywords.
Returns
-------
Expand All @@ -3806,7 +3810,10 @@ cdef class RecordBatch(_Tabular):
target_schema = Schema._import_from_c_capsule(requested_schema)

if target_schema != self.schema:
# TODO should protect from trying to cast non-CPU data
if not self.is_cpu:
raise NotImplementedError(
"Casting to a requested schema is only supported for CPU data"
)
try:
casted_batch = self.cast(target_schema, safe=True)
inner_batch = pyarrow_unwrap_batch(casted_batch)
Expand Down Expand Up @@ -3860,6 +3867,24 @@ cdef class RecordBatch(_Tabular):

return pyarrow_wrap_batch(batch)

@property
def device_type(self):
    """
    The device type where the arrays in the RecordBatch reside.

    Returns
    -------
    DeviceAllocationType
        The Python-level enum value describing where the batch's
        buffers are allocated (e.g. ``DeviceAllocationType.CPU``).
    """
    # Delegate to the underlying C++ RecordBatch (held via the
    # sp_batch shared pointer) and wrap the raw CDeviceAllocationType
    # into the Python-level DeviceAllocationType enum.
    return _wrap_device_allocation_type(self.sp_batch.get().device_type())

@property
def is_cpu(self):
    """
    Whether the RecordBatch's arrays are CPU-accessible.

    Returns
    -------
    bool
        True when :attr:`device_type` equals
        ``DeviceAllocationType.CPU``; False for any other device
        (e.g. GPU-resident data).
    """
    return self.device_type == DeviceAllocationType.CPU


def _reconstruct_record_batch(columns, schema):
"""
Expand Down

0 comments on commit 671efda

Please sign in to comment.