Skip to content

Commit

Permalink
Return arrays from ArrayImpl._check_and_rearrange.
Browse files Browse the repository at this point in the history
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

Reverts 5633728

PiperOrigin-RevId: 721908845
  • Loading branch information
emilyfertig authored and Google-ML-Automation committed Feb 5, 2025
1 parent 6327932 commit 691173c
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
28 changes: 21 additions & 7 deletions xla/python/py_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,8 @@ PyArray_Storage::PyArray_Storage(
prev = nullptr;
}

void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding,
absl::Span<const PyArray> py_arrays, bool committed,
bool skip_checks) {
void PyInit_helper(PyArray self, nb::object aval, nb::object sharding,
absl::Span<const PyArray> py_arrays, bool committed) {
auto dtype = nb::cast<nb_dtype>(aval.attr("dtype"));
auto shape = nb::cast<std::vector<int64_t>>(aval.attr("shape"));
auto ifrt_array =
Expand All @@ -492,9 +491,19 @@ void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding,
std::move(shape), std::move(sharding), committed,
py_arrays.at(0).py_client(), Traceback::Get(),
std::move(ifrt_array), xla::PjRtFuture<>());
}

if (!skip_checks) {
self.CheckAndRearrange();
void PyArray::PyInit(PyArray self, nb::object aval, nb::object sharding,
absl::Span<const PyArray> py_arrays, bool committed,
bool skip_checks) {
if (skip_checks) {
PyInit_helper(self, aval, sharding, py_arrays, committed);
} else {
nb::object rearranged_arrays =
self.CheckAndRearrange(py_arrays, sharding, aval);
auto rearranged_py_arrays =
nb::cast<std::vector<PyArray>>(rearranged_arrays);
PyInit_helper(self, aval, sharding, rearranged_py_arrays, committed);
}
}

Expand Down Expand Up @@ -595,7 +604,8 @@ PyArray::PyArray(nb::object aval, bool weak_type, nb_dtype dtype,
std::move(result_status));

if (!skip_checks) {
CheckAndRearrange();
this->attr("_arrays") = this->attr("_check_and_rearrange")(
this->attr("_arrays"), this->attr("_sharding"), this->attr("aval"));
}
}

Expand All @@ -607,7 +617,11 @@ const PyArray::Storage& PyArray::GetStorage() const {
return *GetPyArrayStorageFromObject(reinterpret_cast<PyArrayObject*>(ptr()));
}

void PyArray::CheckAndRearrange() { this->attr("_check_and_rearrange")(); }
nb::object PyArray::CheckAndRearrange(const absl::Span<const PyArray> py_arrays,
const nb::object sharding,
const nb::object aval) {
return this->attr("_check_and_rearrange")(py_arrays, sharding, aval);
}

void PyArray::SetIfrtArray(tsl::RCReference<ifrt::Array> ifrt_array) {
GetStorage().ifrt_array = std::move(ifrt_array);
Expand Down
4 changes: 3 additions & 1 deletion xla/python/py_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ class PyArray : public nanobind::object {
private:
absl::StatusOr<PyArray> AssertUnsharded(absl::string_view api);

void CheckAndRearrange();
nanobind::object CheckAndRearrange(absl::Span<const PyArray> py_arrays,
nanobind::object sharding,
nanobind::object aval);

void SetIfrtArray(tsl::RCReference<ifrt::Array> ifrt_array);

Expand Down
2 changes: 1 addition & 1 deletion xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
_version = 309
_version = 310

# Version number for MLIR:Python components.
mlir_api_version = 57
Expand Down

0 comments on commit 691173c

Please sign in to comment.