diff --git a/test/test_pallas.py b/test/test_pallas.py
index 90bab0f8c92..6ff2a229bca 100644
--- a/test/test_pallas.py
+++ b/test/test_pallas.py
@@ -32,7 +32,7 @@ def test_tpu_custom_call_pallas_add(self):
     expected_output = x + y

     output = torch.arange(8, dtype=torch.int).to("xla")
-    torch_xla._XLAC._xla_tpu_custom_call_(output, [x, y], payload)
+    torch_xla._XLAC._xla_tpu_custom_call_([output], [x, y], payload)
     self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -46,7 +46,7 @@ def test_tpu_custom_call_pallas_add_one(self):
     expected_output = x + 1

     output = torch.arange(8, dtype=torch.int).to("xla")
-    torch_xla._XLAC._xla_tpu_custom_call_(output, [x], payload)
+    torch_xla._XLAC._xla_tpu_custom_call_([output], [x], payload)
     self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -60,7 +60,7 @@ def test_tpu_custom_call_pallas_raise(self):

     # _xla_tpu_custom_call_ requires at least one input.
     with self.assertRaises(RuntimeError):
-      torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
+      torch_xla._XLAC._xla_tpu_custom_call_([output], [], payload)
       output.cpu()

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -87,11 +87,11 @@ def attention(q, k, v):

     expected_o = attention(q, k, v)

-    torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)
+    torch_xla._XLAC._xla_tpu_custom_call_([o], [q, k, v], payload)
     self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
-  # TODO: Make the tpu_custom_call_ as functional.
+  @unittest.skip("TODO: Make tpu_custom_call_ functional.")
   @unittest.mock.patch.dict(os.environ, {"XLA_DISABLE_FUNCTIONALIZATION": "1"})
   def test_tpu_custom_call_pallas_add_one_dynamo(self):
     # This payload is generated by the following Pallas code:
@@ -106,7 +106,7 @@ def test_tpu_custom_call_pallas_add_one_dynamo(self):
     import torch_xla.experimental.custom_kernel

     def add_one_pallas(output, inputs, payload):
-      torch.ops.xla.tpu_custom_call_(output, inputs, payload)
+      torch.ops.xla.tpu_custom_call(output, inputs, payload)

     compiled_add_one_pallas = torch.compile(
         add_one_pallas, backend='openxla', fullgraph=True)
@@ -150,8 +150,8 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
                                                          x.dtype))(x, y)

     from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
-    pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y:
-                                        (x.shape, x.dtype))
+    pt_kernel = make_kernel_from_pallas(add_vectors,
+                                        lambda x, y: [(x.shape, x.dtype)])

     dtypes = [
         torch.float32, torch.float
@@ -179,7 +179,7 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):
     from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
     from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
     flash_attention_kernel = make_kernel_from_pallas(
-        flash_attention, lambda q, k, v: (q.shape, q.dtype))
+        flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

     def attention(q, k, v):
       attn_weight = q @ k.transpose(-2, -1)
@@ -255,6 +255,34 @@ def test_flash_attention_wrapper_bf16(self):
     # No exception being raised.
     o = flash_attention(q, k, v)

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_multiple_returns(self):
+    import jax._src.pallas.mosaic.pallas_call_registration
+
+    def add_minus_vectors_kernel(x_ref, y_ref, o1_ref, o2_ref):
+      x, y = x_ref[...], y_ref[...]
+      o1_ref[...] = x + y
+      o2_ref[...] = x - y
+
+    @jax.jit
+    def add_minus_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
+      out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
+      return pl.pallas_call(
+          add_minus_vectors_kernel, out_shape=[out_shape, out_shape])(x, y)
+
+    from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
+    pt_kernel = make_kernel_from_pallas(
+        add_minus_vectors, lambda x, y: [(x.shape, x.dtype),
+                                         (x.shape, x.dtype)])
+    x = torch.arange(8, device="xla", dtype=torch.float)
+    o = pt_kernel(x, x)
+    self.assertEqual(len(o), 2)
+
+    expected_o0 = x + x
+    expected_o1 = x - x
+    self.assertTrue(torch.allclose(o[0].cpu(), expected_o0.cpu()))
+    self.assertTrue(torch.allclose(o[1].cpu(), expected_o1.cpu()))
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 97076eb3756..a11fb837479 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -2253,11 +2253,11 @@ void InitXlaModuleBindings(py::module m) {
           return XLANativeFunctions::set_(self, source);
         });
   m.def("_xla_tpu_custom_call_",
-        [](at::Tensor& output, const std::vector<at::Tensor>& inputs,
-           const std::string& payload) {
-          auto x_output = bridge::GetXlaTensor(output);
+        [](const std::vector<at::Tensor>& outputs,
+           const std::vector<at::Tensor>& inputs, const std::string& payload) {
+          auto x_outputs = bridge::GetXlaTensors(outputs);
           return tensor_methods::tpu_custom_call_(
-              x_output, bridge::GetXlaTensors(inputs), payload);
+              x_outputs, bridge::GetXlaTensors(inputs), payload);
         });
   m.def("_set_xla_custom_op_name_prefix",
         [](const at::Tensor& input, const std::string& op_name_prefix,
diff --git a/torch_xla/csrc/ops/tpu_custom_call.cpp b/torch_xla/csrc/ops/tpu_custom_call.cpp
index 11f03bde7f5..7539922037b 100644
--- a/torch_xla/csrc/ops/tpu_custom_call.cpp
+++ b/torch_xla/csrc/ops/tpu_custom_call.cpp
@@ -10,7 +10,8 @@ TpuCustomCall::TpuCustomCall(torch::lazy::OpList inputs,
                              xla::Shape output_shape,
                              const std::string& payload)
     : XlaNode(xla_tpu_custom_call, inputs, std::move(output_shape),
-              /*num_outputs=*/1, torch::lazy::MHash(payload)),
+              /*num_outputs=*/output_shape.tuple_shapes_size(),
+              torch::lazy::MHash(payload)),
       payload_(payload) {}

 torch::lazy::NodePtr TpuCustomCall::Clone(torch::lazy::OpList operands) const {
@@ -23,8 +24,8 @@ XlaOpVector TpuCustomCall::Lower(LoweringContext* loctx) const {
   for (auto& operand : operands()) {
     inputs.push_back(loctx->GetOutputOp(operand));
   }
-  xla::XlaOp output = BuildTpuCustomCall(inputs, xla_shape(), payload_);
-  return ReturnOp(output, loctx);
+  auto output = BuildTpuCustomCall(inputs, xla_shape(), payload_);
+  return ReturnOps(output, loctx);
 }

 std::string TpuCustomCall::ToString() const {
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index bb9ea602541..d89a4afbf28 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -528,15 +528,25 @@ void custom_sharding_(
   input->SetShardingSpec(*sharding_spec);
 }

-void tpu_custom_call_(XLATensorPtr& output,
+void tpu_custom_call_(const std::vector<XLATensorPtr>& outputs,
                       const std::vector<XLATensorPtr>& inputs,
                       const std::string& payload) {
   std::vector<torch::lazy::Value> values;
   for (const auto& input : inputs) {
     values.push_back(input->GetIrValue());
   }
-  output->SetInPlaceIrValue(torch::lazy::MakeNode<TpuCustomCall>(
-      values, output->shape().get(), payload));
+
+  // TODO: Let's see if we can do some shape inference here.
+  std::vector<xla::Shape> output_shapes;
+  for (const auto& output : outputs) {
+    output_shapes.push_back(output->shape().get());
+  }
+
+  auto node = torch::lazy::MakeNode<TpuCustomCall>(
+      values, xla::ShapeUtil::MakeTupleShape(output_shapes), payload);
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    outputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
+  }
 }

 XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h
index 65c4324f060..96a43cb675c 100644
--- a/torch_xla/csrc/tensor_methods.h
+++ b/torch_xla/csrc/tensor_methods.h
@@ -82,7 +82,7 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
 void custom_sharding_(const XLATensorPtr& input,
                       const std::shared_ptr<XLATensor::ShardingSpec>& spec);

-void tpu_custom_call_(XLATensorPtr& output,
+void tpu_custom_call_(const std::vector<XLATensorPtr>& output,
                       const std::vector<XLATensorPtr>& inputs,
                       const std::string& payload);

diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 6c6ba3afe9d..34bfa1e306e 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -1250,9 +1250,12 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
                              {input}, ShapeHelper::ShapeOfXlaOp(input));
 }

-xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
-                              const xla::Shape& output_shape,
-                              const std::string& payload) {
+std::vector<xla::XlaOp> BuildTpuCustomCall(
+    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
+    const std::string& payload) {
+  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
+  XLA_CHECK(output_shape.IsTuple()) << "output_shape is not a tuple";
+
   // We need to enforce the default C-order (major-to-minor) layouts for inputs
   // to Mosaic and outputs from Mosaic.
   std::vector<xla::Shape> input_shapes;
@@ -1262,15 +1265,34 @@ xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
     input_shapes.push_back(MakeTorchTensorLayout(
         shape.dimensions(), shape.dynamic_dimensions(), shape.element_type()));
   }
-  xla::Shape output_shape_impl = MakeTorchTensorLayout(
-      output_shape.dimensions(), output_shape.dynamic_dimensions(),
-      output_shape.element_type());
-  XLA_CHECK(inputs.size() > 0) << "inputs are empty";
-  return xla::CustomCallWithLayout(inputs[0].builder(),
-                                   /*call_target_name=*/"tpu_custom_call",
-                                   inputs, output_shape_impl, input_shapes,
-                                   payload);
+  std::vector<xla::Shape> output_shapes;
+  output_shapes.reserve(output_shape.tuple_shapes_size());
+  for (int i = 0; i < output_shape.tuple_shapes_size(); ++i) {
+    const xla::Shape& shape = output_shape.tuple_shapes(i);
+    output_shapes.push_back(MakeTorchTensorLayout(
+        shape.dimensions(), shape.dynamic_dimensions(), shape.element_type()));
+  }
+
+  // Mosaic has some checks that disallow using a tuple output for a
+  // single element.
+  if (output_shapes.size() == 1) {
+    return {xla::CustomCallWithLayout(inputs[0].builder(),
+                                      /*call_target_name=*/"tpu_custom_call",
+                                      inputs, output_shapes[0], input_shapes,
+                                      payload)};
+  }
+
+  xla::XlaOp outputs = xla::CustomCallWithLayout(
+      inputs[0].builder(),
+      /*call_target_name=*/"tpu_custom_call", inputs,
+      xla::ShapeUtil::MakeTupleShape(output_shapes), input_shapes, payload);
+  std::vector<xla::XlaOp> result;
+  result.reserve(output_shapes.size());
+  for (int i = 0; i < output_shapes.size(); ++i) {
+    result.push_back(xla::GetTupleElement(outputs, i));
+  }
+  return result;
 }

 }  // namespace torch_xla
diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h
index f0c74dff991..45014c5f4fb 100644
--- a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -152,9 +152,9 @@ xla::XlaOp BuildUpperTriangle(xla::XlaOp input);

 xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);

-xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
-                              const xla::Shape& output_shape,
-                              const std::string& payload);
+std::vector<xla::XlaOp> BuildTpuCustomCall(
+    const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
+    const std::string& payload);

 }  // namespace torch_xla

diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py
index a86c2af1b0e..70513cff69f 100644
--- a/torch_xla/experimental/custom_kernel.py
+++ b/torch_xla/experimental/custom_kernel.py
@@ -124,10 +124,20 @@ def wrapped_kernel(kernel: Callable,
        kernel, static_argnames=static_argnames).lower(*jax_args,
                                                       **kwargs).compiler_ir()
    payload = _extract_backend_config(ir)
-    output_shape, output_dtype = output_shape_dtype_fn(*args)
-    output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
-    torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
-    return output
+    # TODO: We can consider supporting non-array output.
+    outputs = []
+    output_shape_dtype = output_shape_dtype_fn(*args)
+    assert isinstance(output_shape_dtype,
+                      list), "The output_shape_dtype_fn should return a list."
+    for output_shape, output_dtype in output_shape_dtype:
+      outputs.append(
+          torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device()))
+    torch_xla._XLAC._xla_tpu_custom_call_(outputs, args, payload)
+
+    # Make the output easier to use.
+    if len(outputs) == 1:
+      return outputs[0]
+    return tuple(outputs)

   return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)
@@ -150,7 +160,7 @@ def flash_attention(

   # TODO: Support segment_ids.
   flash_attention_kernel = make_kernel_from_pallas(
-      tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))
+      tpu_flash_attention.flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

   # The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
   # It yields much better performance than the default block_sizes.
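
Note: the sketch below is not part of the patch. It restates the new test_multiple_returns test above as a standalone usage example of the updated make_kernel_from_pallas API: the shape callback now returns a list of (shape, dtype) pairs, one per kernel output, and the wrapped kernel returns a tuple when there is more than one output. It assumes a TPU device with JAX/Pallas and torch_xla available.

# Usage sketch (not part of the patch); mirrors test_multiple_returns above.
import jax
from jax.experimental import pallas as pl
import torch
import torch_xla  # registers the "xla" device
import jax._src.pallas.mosaic.pallas_call_registration  # TPU lowering, as in the test
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas


def add_minus_vectors_kernel(x_ref, y_ref, o1_ref, o2_ref):
  # One Pallas kernel that writes two outputs.
  x, y = x_ref[...], y_ref[...]
  o1_ref[...] = x + y
  o2_ref[...] = x - y


@jax.jit
def add_minus_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return pl.pallas_call(
      add_minus_vectors_kernel, out_shape=[out_shape, out_shape])(x, y)


# The shape callback returns a list of (shape, dtype) pairs, one per output.
pt_kernel = make_kernel_from_pallas(
    add_minus_vectors, lambda x, y: [(x.shape, x.dtype), (x.shape, x.dtype)])

x = torch.arange(8, device="xla", dtype=torch.float)
o_add, o_sub = pt_kernel(x, x)  # a tuple is returned for multiple outputs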