[Pallas] Support multiple outputs (#6844)
Summary:
This pull request supports Pallas kernels that output multiple results. The current implementation accepts an array of output tensors and performs in-place updates on them. However, this somehow breaks Dynamo; I will fix the Dynamo issue in a follow-up.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_multiple_returns
alanwaketan authored Mar 29, 2024
1 parent 7bbe9d7 commit 8f095fc
Showing 8 changed files with 110 additions and 39 deletions.
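
For reference, the following is a minimal end-to-end sketch of the new multi-output flow, adapted from the test_multiple_returns case added in this diff. It assumes a TPU device (PJRT_DEVICE=TPU) and a JAX/Pallas install compatible with torch_xla; the kernel, shape callback, and imports mirror the test rather than introducing any new API.

import jax
import jax._src.pallas.mosaic.pallas_call_registration  # as in the test: registers the Mosaic TPU lowering
from jax.experimental import pallas as pl
import torch
import torch_xla  # registers the "xla" device with PyTorch
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas


def add_minus_vectors_kernel(x_ref, y_ref, o1_ref, o2_ref):
  # Two output refs: the kernel writes the sum and the difference in place.
  x, y = x_ref[...], y_ref[...]
  o1_ref[...] = x + y
  o2_ref[...] = x - y


@jax.jit
def add_minus_vectors(x: jax.Array, y: jax.Array):
  out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
  # Passing a list as out_shape makes pallas_call produce two results.
  return pl.pallas_call(
      add_minus_vectors_kernel, out_shape=[out_shape, out_shape])(x, y)


# The shape/dtype callback now returns a list of (shape, dtype) pairs,
# one entry per kernel output.
pt_kernel = make_kernel_from_pallas(
    add_minus_vectors, lambda x, y: [(x.shape, x.dtype), (x.shape, x.dtype)])

x = torch.arange(8, device="xla", dtype=torch.float)
o = pt_kernel(x, x)  # a tuple of two XLA tensors
assert torch.allclose(o[0].cpu(), (x + x).cpu())
assert torch.allclose(o[1].cpu(), (x - x).cpu())

Per the custom_kernel.py change below, when the callback returns a single (shape, dtype) pair the wrapper unwraps the result back to a single tensor, so the existing single-output call sites keep working unchanged.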
46 changes: 37 additions & 9 deletions test/test_pallas.py
@@ -32,7 +32,7 @@ def test_tpu_custom_call_pallas_add(self):
expected_output = x + y
output = torch.arange(8, dtype=torch.int).to("xla")

torch_xla._XLAC._xla_tpu_custom_call_(output, [x, y], payload)
torch_xla._XLAC._xla_tpu_custom_call_([output], [x, y], payload)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -46,7 +46,7 @@ def test_tpu_custom_call_pallas_add_one(self):
expected_output = x + 1
output = torch.arange(8, dtype=torch.int).to("xla")

torch_xla._XLAC._xla_tpu_custom_call_(output, [x], payload)
torch_xla._XLAC._xla_tpu_custom_call_([output], [x], payload)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -60,7 +60,7 @@ def test_tpu_custom_call_pallas_raise(self):

# _xla_tpu_custom_call_ requires at least one input.
with self.assertRaises(RuntimeError):
torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
torch_xla._XLAC._xla_tpu_custom_call_([output], [], payload)
output.cpu()

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
@@ -87,11 +87,11 @@ def attention(q, k, v):

expected_o = attention(q, k, v)

torch_xla._XLAC._xla_tpu_custom_call_(o, [q, k, v], payload)
torch_xla._XLAC._xla_tpu_custom_call_([o], [q, k, v], payload)
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
# TODO: Make the tpu_custom_call_ as functional.
@unittest.skip("TODO: Make the tpu_custom_call_ as functional.")
@unittest.mock.patch.dict(os.environ, {"XLA_DISABLE_FUNCTIONALIZATION": "1"})
def test_tpu_custom_call_pallas_add_one_dynamo(self):
# This payload is generated by the following Pallas code:
@@ -106,7 +106,7 @@ def test_tpu_custom_call_pallas_add_one_dynamo(self):
import torch_xla.experimental.custom_kernel

def add_one_pallas(output, inputs, payload):
torch.ops.xla.tpu_custom_call_(output, inputs, payload)
torch.ops.xla.tpu_custom_call(output, inputs, payload)

compiled_add_one_pallas = torch.compile(
add_one_pallas, backend='openxla', fullgraph=True)
@@ -150,8 +150,8 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
x.dtype))(x, y)

from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(add_vectors, lambda x, y:
(x.shape, x.dtype))
pt_kernel = make_kernel_from_pallas(add_vectors,
lambda x, y: [(x.shape, x.dtype)])

dtypes = [
torch.float32, torch.float
@@ -179,7 +179,7 @@ def test_tpu_custom_call_pallas_wrap_flash_attention(self):
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
flash_attention_kernel = make_kernel_from_pallas(
flash_attention, lambda q, k, v: (q.shape, q.dtype))
flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
@@ -255,6 +255,34 @@ def test_flash_attention_wrapper_bf16(self):
# No exception being raised.
o = flash_attention(q, k, v)

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_multiple_returns(self):
import jax._src.pallas.mosaic.pallas_call_registration

def add_minus_vectors_kernel(x_ref, y_ref, o1_ref, o2_ref):
x, y = x_ref[...], y_ref[...]
o1_ref[...] = x + y
o2_ref[...] = x - y

@jax.jit
def add_minus_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
out_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
return pl.pallas_call(
add_minus_vectors_kernel, out_shape=[out_shape, out_shape])(x, y)

from torch_xla.experimental.custom_kernel import make_kernel_from_pallas
pt_kernel = make_kernel_from_pallas(
add_minus_vectors, lambda x, y: [(x.shape, x.dtype),
(x.shape, x.dtype)])
x = torch.arange(8, device="xla", dtype=torch.float)
o = pt_kernel(x, x)
self.assertEqual(len(o), 2)

expected_o0 = x + x
expected_o1 = x - x
self.assertTrue(torch.allclose(o[0].cpu(), expected_o0.cpu()))
self.assertTrue(torch.allclose(o[1].cpu(), expected_o1.cpu()))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
8 changes: 4 additions & 4 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2253,11 +2253,11 @@ void InitXlaModuleBindings(py::module m) {
return XLANativeFunctions::set_(self, source);
});
m.def("_xla_tpu_custom_call_",
[](at::Tensor& output, const std::vector<at::Tensor>& inputs,
const std::string& payload) {
auto x_output = bridge::GetXlaTensor(output);
[](const std::vector<at::Tensor>& outputs,
const std::vector<at::Tensor>& inputs, const std::string& payload) {
auto x_outputs = bridge::GetXlaTensors(outputs);
return tensor_methods::tpu_custom_call_(
x_output, bridge::GetXlaTensors(inputs), payload);
x_outputs, bridge::GetXlaTensors(inputs), payload);
});
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
7 changes: 4 additions & 3 deletions torch_xla/csrc/ops/tpu_custom_call.cpp
@@ -10,7 +10,8 @@ TpuCustomCall::TpuCustomCall(torch::lazy::OpList inputs,
xla::Shape output_shape,
const std::string& payload)
: XlaNode(xla_tpu_custom_call, inputs, std::move(output_shape),
/*num_outputs=*/1, torch::lazy::MHash(payload)),
/*num_outputs=*/output_shape.tuple_shapes_size(),
torch::lazy::MHash(payload)),
payload_(payload) {}

torch::lazy::NodePtr TpuCustomCall::Clone(torch::lazy::OpList operands) const {
@@ -23,8 +24,8 @@ XlaOpVector TpuCustomCall::Lower(LoweringContext* loctx) const {
for (auto& operand : operands()) {
inputs.push_back(loctx->GetOutputOp(operand));
}
xla::XlaOp output = BuildTpuCustomCall(inputs, xla_shape(), payload_);
return ReturnOp(output, loctx);
auto output = BuildTpuCustomCall(inputs, xla_shape(), payload_);
return ReturnOps(output, loctx);
}

std::string TpuCustomCall::ToString() const {
16 changes: 13 additions & 3 deletions torch_xla/csrc/tensor_methods.cpp
@@ -528,15 +528,25 @@ void custom_sharding_(
input->SetShardingSpec(*sharding_spec);
}

void tpu_custom_call_(XLATensorPtr& output,
void tpu_custom_call_(const std::vector<XLATensorPtr>& outputs,
const std::vector<XLATensorPtr>& inputs,
const std::string& payload) {
std::vector<torch::lazy::Value> values;
for (const auto& input : inputs) {
values.push_back(input->GetIrValue());
}
output->SetInPlaceIrValue(torch::lazy::MakeNode<TpuCustomCall>(
values, output->shape().get(), payload));

// TODO: Let's see if we can do some shape inference here.
std::vector<xla::Shape> output_shapes;
for (const auto& output : outputs) {
output_shapes.push_back(output->shape().get());
}

auto node = torch::lazy::MakeNode<TpuCustomCall>(
values, xla::ShapeUtil::MakeTupleShape(output_shapes), payload);
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
}
}

XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor_methods.h
@@ -82,7 +82,7 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
void custom_sharding_(const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& spec);

void tpu_custom_call_(XLATensorPtr& output,
void tpu_custom_call_(const std::vector<XLATensorPtr>& output,
const std::vector<XLATensorPtr>& inputs,
const std::string& payload);

44 changes: 33 additions & 11 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1250,9 +1250,12 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
{input}, ShapeHelper::ShapeOfXlaOp(input));
}

xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload) {
std::vector<xla::XlaOp> BuildTpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload) {
XLA_CHECK(inputs.size() > 0) << "inputs are empty";
XLA_CHECK(output_shape.IsTuple()) << "output_shape is not a tuple";

// We need to enforce the default C-order (major-to-minor) layouts for inputs
// to Mosaic and outputs from Mosaic.
std::vector<xla::Shape> input_shapes;
@@ -1262,15 +1265,34 @@ xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
input_shapes.push_back(MakeTorchTensorLayout(
shape.dimensions(), shape.dynamic_dimensions(), shape.element_type()));
}
xla::Shape output_shape_impl = MakeTorchTensorLayout(
output_shape.dimensions(), output_shape.dynamic_dimensions(),
output_shape.element_type());

XLA_CHECK(inputs.size() > 0) << "inputs are empty";
return xla::CustomCallWithLayout(inputs[0].builder(),
/*call_target_name=*/"tpu_custom_call",
inputs, output_shape_impl, input_shapes,
payload);
std::vector<xla::Shape> output_shapes;
output_shapes.reserve(output_shape.tuple_shapes_size());
for (int i = 0; i < output_shape.tuple_shapes_size(); ++i) {
const xla::Shape& shape = output_shape.tuple_shapes(i);
output_shapes.push_back(MakeTorchTensorLayout(
shape.dimensions(), shape.dynamic_dimensions(), shape.element_type()));
}

// Mosaic has some weird checks that disallow using a tuple output for single
// element.
if (output_shapes.size() == 1) {
return {xla::CustomCallWithLayout(inputs[0].builder(),
/*call_target_name=*/"tpu_custom_call",
inputs, output_shapes[0], input_shapes,
payload)};
}

xla::XlaOp outputs = xla::CustomCallWithLayout(
inputs[0].builder(),
/*call_target_name=*/"tpu_custom_call", inputs,
xla::ShapeUtil::MakeTupleShape(output_shapes), input_shapes, payload);
std::vector<xla::XlaOp> result;
result.reserve(output_shapes.size());
for (int i = 0; i < output_shapes.size(); ++i) {
result.push_back(xla::GetTupleElement(outputs, i));
}
return result;
}

} // namespace torch_xla
6 changes: 3 additions & 3 deletions torch_xla/csrc/xla_lower_util.h
@@ -152,9 +152,9 @@ xla::XlaOp BuildUpperTriangle(xla::XlaOp input);

xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);

xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload);
std::vector<xla::XlaOp> BuildTpuCustomCall(
const std::vector<xla::XlaOp>& inputs, const xla::Shape& output_shape,
const std::string& payload);

} // namespace torch_xla

20 changes: 15 additions & 5 deletions torch_xla/experimental/custom_kernel.py
@@ -124,10 +124,20 @@ def wrapped_kernel(kernel: Callable,
kernel, static_argnames=static_argnames).lower(*jax_args,
**kwargs).compiler_ir()
payload = _extract_backend_config(ir)
output_shape, output_dtype = output_shape_dtype_fn(*args)
output = torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device())
torch_xla._XLAC._xla_tpu_custom_call_(output, args, payload)
return output
# TODO: We can consider supporting un-array output.
outputs = []
output_shape_dtype = output_shape_dtype_fn(*args)
assert isinstance(output_shape_dtype,
list), "The output_shape_dtype_fn should return a list."
for output_shape, output_dtype in output_shape_dtype:
outputs.append(
torch.empty(output_shape, dtype=output_dtype).to(xm.xla_device()))
torch_xla._XLAC._xla_tpu_custom_call_(outputs, args, payload)

# Make the output easier to use.
if len(outputs) == 1:
return outputs[0]
return tuple(outputs)

return functools.partial(wrapped_kernel, kernel, output_shape_dtype_fn)

@@ -150,7 +160,7 @@ def flash_attention(

# TODO: Support segment_ids.
flash_attention_kernel = make_kernel_from_pallas(
tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))
tpu_flash_attention.flash_attention, lambda q, k, v: [(q.shape, q.dtype)])

# The block_sizes configuration is copied from https://github.com/google/maxtext/blob/0fee320451738166c8e596dc63a57a4673671576/MaxText/layers/attentions.py#L215-L240
# It yields much better performance than the default block_sizes.
