[Pallas] PoC Integration (#6340)
Summary:
This is a PoC for the Pallas integration. Currently, it can run Pallas kernels that take arbitrary tensors as inputs and produce a single output tensor. The design doc is here: go/pytorch-xla-pallas.

Test Plan:
PJRT_DEVICE=TPU python test/test_operations.py -v -k test_tpu_custom_call
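
For reference, below is a minimal sketch of the kind of kernel the new path runs, written against the public JAX Pallas API (jax.experimental.pallas). It mirrors the add_vectors_kernel quoted in the tests and is illustrative only; it is not part of this commit.

# Illustrative sketch (not part of this commit): the element-wise add kernel
# the first test's payload was generated from, expressed with JAX Pallas.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_vectors_kernel(x_ref, y_ref, o_ref):
  # Refs are read and written with NumPy-style indexing inside the kernel.
  x, y = x_ref[...], y_ref[...]
  o_ref[...] = x + y

@jax.jit
def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
  return pl.pallas_call(
      add_vectors_kernel,
      out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
  )(x, y)

# add_vectors(jnp.arange(8), jnp.arange(8)) computes [0, 2, 4, ..., 14] on TPU.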
alanwaketan authored Jan 30, 2024
1 parent de5d764 commit 56db8f2
Showing 10 changed files with 150 additions and 0 deletions.
44 changes: 44 additions & 0 deletions test/test_operations.py
@@ -1867,6 +1867,50 @@ def test_patched_linear_1D_bias(self):
self.assertTrue(
torch.allclose(linear.bias.grad.cpu(), linear_cpu.bias.grad))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add(self):
# This payload is generated by the following Pallas code:
# def add_vectors_kernel(x_ref, y_ref, o_ref):
# x, y = x_ref[...], y_ref[...]
# o_ref[...] = x + y
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAErCwEDBQcJAQMLAwUDDQcFDxEJBRMVA2lNDQFLBw8LEw8PDwsPMwsLCwtlCwsLCwsPCw8PEwsTDwsTDwsPDxMLDwUDYQENGwcTDxsPAsICHx0rLQUXAwMnKRURNx1HSRELAQUZHTM1AwsVFxkbHw0hDSMlBRsBAQUdDQlhZmZpbmVfbWFwPChkMCkgLT4gKGQwKT4ABR8FIQUjBSUFJxEDAQUpFS8JHQ8xFwUTAQUrFwUdAR05OwUtFwUlAR0/QQUvFUMJHQ9FFwUVAQUxFREJI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF0sDIQcdAycDIQcBAgIFBwEBAQEBAgQEpwUBEAEHAwEFAxEBEwcDFScHAQEBAQEBBwMDBwMDCwYDAwUFAQcHAwMHAwMLBgMDBQUDCwkGPQMFBQkNBwMLBwMDCwYLAwUFBRENBAsHDwURBQABBgMBBQEAdgcz2wsTGdkNCxMjIR0pJ0MNCwsTDw8PDQkLEWJ1aWx0aW4AZnVuYwB0cHUAYXJpdGgAdmVjdG9yAG1vZHVsZQByZXR1cm4AY29uc3RhbnQAYWRkaQBsb2FkAHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AGFkZF92ZWN0b3JzX2tlcm5lbABkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAbWFpbgB2YWx1ZQAvZ2V0W3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQBhZGRfdmVjdG9ycwA8bW9kdWxlPgAvYWRkAC9zd2FwW3RyZWU9UHlUcmVlRGVmKChDdXN0b21Ob2RlKE5ESW5kZXhlclsoUHlUcmVlRGVmKChDdXN0b21Ob2RlKFNsaWNlWygwLCA4KV0sIFtdKSwpKSwgKDgsKSwgKCkpXSwgW10pLCkpXQA=\", \"needs_layout_passes\": true}}"

x = torch.arange(8, dtype=torch.int).to("xla")
y = torch.arange(8, dtype=torch.int).to("xla")
expected_output = x + y
output = torch.arange(8, dtype=torch.int).to("xla")

torch_xla._XLAC._xla_tpu_custom_call_(output, [x, y], payload)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_add_one(self):
# This payload is generated by the following Pallas code:
# def add_vectors_kernel(x_ref, o_ref):
# o_ref[...] = x_ref[...] + 1
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"

x = torch.arange(8, dtype=torch.int).to("xla")
expected_output = x + 1
output = torch.arange(8, dtype=torch.int).to("xla")

torch_xla._XLAC._xla_tpu_custom_call_(output, [x], payload)
self.assertTrue(torch.allclose(output.cpu(), expected_output.cpu()))

@unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
def test_tpu_custom_call_pallas_raise(self):
# This payload is generated by the following Pallas code:
# def add_vectors_kernel(x_ref, o_ref):
# o_ref[...] = x_ref[...] + 1
payload = "{\"custom_call_config\": {\"body\": \"TUzvUgFNTElSMTguMC4wZ2l0AAEtCwEDBQcJAQMLAwUDDQcFDxEJBxMVFwNlSQ0BRwcPCw8PDxMLDzMLCwsLZQsLCwsPCw8LEw8PCxMPCxMTDwsLBQNhAQ0bDxMHFw8CpgIfFSsxBRkdQwMdRQMRCwEDAw8nBRsdKQMDCxUXGRsfCyELIyUFHQEBBR8NCWFmZmluZV9tYXA8KGQwKSAtPiAoZDApPgAFIQUjBSUFJxEHAQUpHS0vBSsXBRsBFTM5HTU3BS0XBS8BHTs9BS8XBUUBAwMPQREDBQUxBTMjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAXRwMhAx0BAgInAyEDAwUFAQEBAQIEBKEFARABBwMBBQMRARMHAxMnBQEBAQEHAxENAwcLBhEDBQUBBQcDBz8DAw0GBwMFAwkJBgcDBQUHCwcDCQ0DBwsGCQMFBQMPDwQJBw0DDwUAAQYDAQUBAMIHNdsLEyEv2QsTIyEdKQ1DDRULCxMPDw8NCQsRYnVpbHRpbgBmdW5jAHRwdQBhcml0aAB2ZWN0b3IAbW9kdWxlAHJldHVybgBjb25zdGFudABhZGRpAGxvYWQAYnJvYWRjYXN0AHN0b3JlAC9ob21lL2p3dGFuL3BhbGxhcy9wYWxsYXNfYWRkLnB5AHZhbHVlAGRpbWVuc2lvbl9zZW1hbnRpY3MAZnVuY3Rpb25fdHlwZQBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBzeW1fbmFtZQBtYWluAC9nZXRbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAGFkZF9vbmVfdmVjdG9yc19rZXJuZWwAYWRkX3ZlY3RvcnNfb25lADxtb2R1bGU+AC9hZGQAL3N3YXBbdHJlZT1QeVRyZWVEZWYoKEN1c3RvbU5vZGUoTkRJbmRleGVyWyhQeVRyZWVEZWYoKEN1c3RvbU5vZGUoU2xpY2VbKDAsIDgpXSwgW10pLCkpLCAoOCwpLCAoKSldLCBbXSksKSldAA==\", \"needs_layout_passes\": true}}"

output = torch.arange(8, dtype=torch.int).to("xla")

# _xla_tpu_custom_call_ requires at least one input.
with self.assertRaises(RuntimeError):
torch_xla._XLAC._xla_tpu_custom_call_(output, [], payload)
output.cpu()


class MNISTComparator(nn.Module):

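The payload strings embedded in the tests above were generated offline from the Pallas kernels quoted in the comments. The exact recipe is described in the design doc (go/pytorch-xla-pallas); one plausible way to obtain such a string, assuming the payload is the backend_config attached to the tpu_custom_call op in JAX's lowered StableHLO, is sketched below. This is an assumption for illustration, not part of this commit.

# Hedged sketch: dump the StableHLO that JAX lowers a pallas_call to and read
# off the backend_config of the tpu_custom_call op. The attribute name and the
# JSON wrapping are assumptions inferred from the payload format in the tests.
import jax
import jax.numpy as jnp

x = jnp.arange(8, dtype=jnp.int32)
lowered = jax.jit(add_vectors).lower(x, x)  # add_vectors from the earlier sketch
print(lowered.as_text())
# Look for: stablehlo.custom_call @tpu_custom_call(...) {backend_config = "..."}
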
7 changes: 7 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -2170,6 +2170,13 @@ void InitXlaModuleBindings(py::module m) {
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});
m.def("_xla_tpu_custom_call_",
[](at::Tensor& output, const std::vector<at::Tensor>& inputs,
const std::string& payload) {
auto x_output = bridge::GetXlaTensor(output);
return tensor_methods::tpu_custom_call_(
x_output, bridge::GetXlaTensors(inputs), payload);
});
m.def("_set_xla_custom_op_name_prefix",
[](const at::Tensor& input, const std::string& op_name_prefix,
size_t max_call_stack_depth) -> bool {
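The binding above writes its result into a caller-allocated output tensor and requires at least one input. A minimal, hypothetical convenience wrapper (the helper name tpu_custom_call and its signature are not part of this commit) could look like this:

# Hypothetical wrapper around the new binding (not in this commit).
from typing import List
import torch
import torch_xla

def tpu_custom_call(inputs: List[torch.Tensor], payload: str,
                    out_shape, out_dtype) -> torch.Tensor:
  # The op writes in place, so the caller preallocates the output on the same
  # XLA device as the inputs.
  output = torch.empty(out_shape, dtype=out_dtype, device=inputs[0].device)
  torch_xla._XLAC._xla_tpu_custom_call_(output, inputs, payload)
  return output
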
36 changes: 36 additions & 0 deletions torch_xla/csrc/ops/tpu_custom_call.cpp
@@ -0,0 +1,36 @@
#include "torch_xla/csrc/ops/tpu_custom_call.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/xla_lower_util.h"

namespace torch_xla {

TpuCustomCall::TpuCustomCall(torch::lazy::OpList inputs,
xla::Shape output_shape,
const std::string& payload)
: XlaNode(xla_tpu_custom_call, inputs, std::move(output_shape),
/*num_outputs=*/1, torch::lazy::MHash(payload)),
payload_(payload) {}

torch::lazy::NodePtr TpuCustomCall::Clone(torch::lazy::OpList operands) const {
return torch::lazy::MakeNode<TpuCustomCall>(operands, xla_shape(), payload_);
}

XlaOpVector TpuCustomCall::Lower(LoweringContext* loctx) const {
std::vector<xla::XlaOp> inputs;
inputs.reserve(operands().size());
for (auto& operand : operands()) {
inputs.push_back(loctx->GetOutputOp(operand));
}
xla::XlaOp output = BuildTpuCustomCall(inputs, xla_shape(), payload_);
return ReturnOp(output, loctx);
}

std::string TpuCustomCall::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", " << payload_;
return ss.str();
}

} // namespace torch_xla
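Since TpuCustomCall is an ordinary lazy-tensor IR node and ToString() appends the payload, the pending graph can be inspected before it is executed. A small sketch, assuming the existing IR-dumping binding _get_xla_tensors_text and the payload from the first test:

# Sketch: queue the custom call, then dump the lazy IR before execution.
import torch
import torch_xla

x = torch.arange(8, dtype=torch.int).to("xla")
y = torch.arange(8, dtype=torch.int).to("xla")
output = torch.empty(8, dtype=torch.int).to("xla")
torch_xla._XLAC._xla_tpu_custom_call_(output, [x, y], payload)  # payload: see first test
print(torch_xla._XLAC._get_xla_tensors_text([output]))
# The dump should contain a node like:
#   xla::tpu_custom_call(...), {"custom_call_config": {"body": ...}}
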
26 changes: 26 additions & 0 deletions torch_xla/csrc/ops/tpu_custom_call.h
@@ -0,0 +1,26 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_
#define XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class TpuCustomCall : public XlaNode {
public:
// Make a TPU custom call with payload, e.g., Mosaic.
TpuCustomCall(torch::lazy::OpList inputs, xla::Shape output_shape,
const std::string& payload);

torch::lazy::NodePtr Clone(torch::lazy::OpList operands) const override;

XlaOpVector Lower(LoweringContext* loctx) const override;

std::string ToString() const override;

private:
std::string payload_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_TPU_CUSTOM_CALL_H_
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.cpp
@@ -33,5 +33,6 @@ const OpKindWrapper xla_tensor_data("xla::tensor_data");
const OpKindWrapper xla_unselect("xla::unselect");
const OpKindWrapper xla_update_slice("xla::update_slice");
const OpKindWrapper xla_custom_sharding("xla::custom_sharding");
const OpKindWrapper xla_tpu_custom_call("xla::tpu_custom_call");

} // namespace torch_xla
1 change: 1 addition & 0 deletions torch_xla/csrc/ops/xla_ops.h
@@ -58,6 +58,7 @@ extern const OpKindWrapper xla_tensor_data;
extern const OpKindWrapper xla_unselect;
extern const OpKindWrapper xla_update_slice;
extern const OpKindWrapper xla_custom_sharding;
extern const OpKindWrapper xla_tpu_custom_call;

} // namespace torch_xla

12 changes: 12 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
@@ -122,6 +122,7 @@
#include "torch_xla/csrc/ops/threshold.h"
#include "torch_xla/csrc/ops/threshold_backward.h"
#include "torch_xla/csrc/ops/topk.h"
#include "torch_xla/csrc/ops/tpu_custom_call.h"
#include "torch_xla/csrc/ops/triangular_solve.h"
#include "torch_xla/csrc/ops/uniform.h"
#include "torch_xla/csrc/ops/unsqueeze.h"
@@ -523,6 +524,17 @@ void custom_sharding_(
input->SetShardingSpec(*sharding_spec);
}

void tpu_custom_call_(XLATensorPtr& output,
const std::vector<XLATensorPtr>& inputs,
const std::string& payload) {
std::vector<torch::lazy::Value> values;
for (const auto& input : inputs) {
values.push_back(input->GetIrValue());
}
output->SetInPlaceIrValue(torch::lazy::MakeNode<TpuCustomCall>(
values, output->shape().get(), payload));
}

XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
std::vector<int64_t> dimensions) {
return input->CreateFrom(torch::lazy::MakeNode<GetDimensionsSize>(
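Note that tpu_custom_call_ only stamps a new IR value onto the output tensor; as with any lazy-tensor op, nothing runs on the device until the graph is materialized. A short sketch, assuming standard torch_xla stepping behavior and the tensors from the earlier example:

# Sketch: the custom call executes only when the lazy graph is materialized.
import torch_xla
import torch_xla.core.xla_model as xm

torch_xla._XLAC._xla_tpu_custom_call_(output, [x, y], payload)  # queues the node
xm.mark_step()   # compiles and runs the graph containing the custom call
# Alternatively, output.cpu() forces a host transfer, which also triggers execution.
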
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.h
@@ -82,6 +82,10 @@ std::pair<XLATensorPtr, torch::lazy::Value> collective_permute(
void custom_sharding_(const XLATensorPtr& input,
const std::shared_ptr<XLATensor::ShardingSpec>& spec);

void tpu_custom_call_(XLATensorPtr& output,
const std::vector<XLATensorPtr>& inputs,
const std::string& payload);

XLATensorPtr get_dimensions_size(const XLATensorPtr& input,
std::vector<int64_t> dimensions);

15 changes: 15 additions & 0 deletions torch_xla/csrc/xla_lower_util.cpp
@@ -1224,4 +1224,19 @@ xla::XlaOp BuildCustomSharding(const xla::XlaOp& input) {
{input}, ShapeHelper::ShapeOfXlaOp(input));
}

xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload) {
std::vector<xla::Shape> input_shapes;
input_shapes.reserve(inputs.size());
for (const auto& input : inputs) {
input_shapes.push_back(ShapeHelper::ShapeOfXlaOp(input));
}

XLA_CHECK(inputs.size() > 0) << "inputs are empty";
return xla::CustomCallWithLayout(inputs[0].builder(),
/*call_target_name=*/"tpu_custom_call",
inputs, output_shape, input_shapes, payload);
}

} // namespace torch_xla
4 changes: 4 additions & 0 deletions torch_xla/csrc/xla_lower_util.h
@@ -150,6 +150,10 @@ xla::XlaOp BuildCdistForward(xla::XlaOp x1, xla::XlaOp x2, xla::XlaOp p,

xla::XlaOp BuildCustomSharding(const xla::XlaOp& input);

xla::XlaOp BuildTpuCustomCall(const std::vector<xla::XlaOp>& inputs,
const xla::Shape& output_shape,
const std::string& payload);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_XLA_LOWER_UTIL_H_
