Skip to content

Commit

Permalink
Lower RandPerm (#6315)
Browse files Browse the repository at this point in the history
Co-authored-by: Mason Chang <masonchang@google.com>
  • Loading branch information
2 people authored and bhavya01 committed Apr 22, 2024
1 parent 4c785a6 commit ec4b301
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 10 deletions.
1 change: 1 addition & 0 deletions codegen/xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ supported:
- random_
- random_.from
- random_.to
- randperm
- reflection_pad2d
- reflection_pad2d_backward
- remainder.Scalar
Expand Down
49 changes: 41 additions & 8 deletions test/cpp/test_aten_xla_tensor_1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1484,14 +1484,47 @@ TEST_F(AtenXlaTensorTest, TestNativeDropoutZeroProbability) {

TEST_F(AtenXlaTensorTest, TestRandperm) {
int n = 5;
torch::Tensor shuffle = torch::randperm(
n, torch::TensorOptions(torch::kLong).device(torch::kXLA));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);
std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
ExpectCounterNotChanged("aten::(?!randperm.generator_out).*",
cpp_test::GetIgnoredCounters());
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterNotChanged("aten::.*", cpp_test::GetIgnoredCounters());
ExpectCounterChanged("xla::randperm", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestRandpermZeroDoesntCrash) {
int n = 0;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle =
torch::randperm(n, torch::TensorOptions(torch::kLong).device(device));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.empty());
});
}

TEST_F(AtenXlaTensorTest, TestRandpermCPUFallback) {
int n = 5;
ForEachDevice([&](const torch::Device& device) {
torch::Tensor shuffle = torch::randperm(
n,
torch::TensorOptions(torch::kLong).device(device).pinned_memory(true));
torch::Tensor shuffle_cpu = CopyToDevice(shuffle, torch::kCPU);

std::vector<int64_t> shuffle_data(shuffle_cpu.data_ptr<int64_t>(),
shuffle_cpu.data_ptr<int64_t>() + n);
EXPECT_TRUE(shuffle_data.size() == n && xla::IsPermutation(shuffle_data));
});

ExpectCounterChanged("aten::.*", cpp_test::GetIgnoredCounters());
}

TEST_F(AtenXlaTensorTest, TestSlice) {
Expand Down
13 changes: 11 additions & 2 deletions test/test_core_aten_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3386,11 +3386,20 @@ def test_aten_prod_dim_int_1(self):
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.prod.dim_int, args, kwargs)

@unittest.skip
# Due to the way randperm isn't on device, we manually assert checks here instead of using
# the existing test harness.
def test_aten_randperm_0(self):
args = (20,)
kwargs = dict()
run_export_and_compare(self, torch.ops.aten.randperm, args, kwargs)
pytorch = torch.randperm(20)

xla = torch.randperm(20, device=xm.xla_device())
xla_detached = xla.detach().cpu()

# Check equal lengths and that the sorted sets are equal. Since these numbers are randomly
# generated there's no way to check that pytorch == pytorch/xla.
self.assertEqual(len(pytorch), len(xla))
self.assertEqual(sorted(set(pytorch)), sorted(set(xla_detached)))

def test_aten_reciprocal_0(self):
args = (torch.randn((10, 10)).to(torch.float32),)
Expand Down
26 changes: 26 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2498,6 +2498,32 @@ at::Tensor& XLANativeFunctions::random_(
return self;
}

at::Tensor XLANativeFunctions::randperm(int64_t n,
c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout,
c10::optional<at::Device> device,
c10::optional<bool> pin_memory) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");

// Only support the basic version of randperm(int64_t) to start. If there are
// any other parameters, fallback to CPU.
bool fallback_to_cpu = false;
fallback_to_cpu |= layout.has_value();
fallback_to_cpu |= pin_memory.has_value() && pin_memory.value() == true;
fallback_to_cpu |= dtype.value() != at::ScalarType::Long;
fallback_to_cpu |= n == 0;

if (fallback_to_cpu) {
return at::native::call_fallback_fn<&xla_cpu_fallback,
ATEN_OP(randperm)>::call(n, dtype,
layout, device,
pin_memory);
}

return bridge::AtenFromXlaTensor(tensor_methods::randperm(
n, GetXlaDeviceOrCurrent(device), at::ScalarType::Long));
}

at::Tensor XLANativeFunctions::reflection_pad2d(const at::Tensor& self,
at::IntArrayRef padding) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
Expand Down
72 changes: 72 additions & 0 deletions torch_xla/csrc/ops/randperm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "torch_xla/csrc/ops/randperm.h"

#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "tsl/platform/stacktrace.h"
#include "xla/client/lib/loops.h"
#include "xla/shape_util.h"

namespace torch_xla {
namespace {

using namespace xla;

xla::Shape NodeOutputShape(int64_t n) {
return xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {n});
}

XlaOp Swap(XlaOp input, XlaOp i, XlaOp j) {
XlaOp i_value = xla::DynamicSlice(input, {i}, /*slice_sizes=*/{1});
XlaOp j_value = xla::DynamicSlice(input, {j}, /*slice_sizes=*/{1});

XlaOp write_i = xla::DynamicUpdateSlice(input, j_value, {i});
XlaOp write_j = xla::DynamicUpdateSlice(write_i, i_value, {j});

return write_j;
}

StatusOr<std::vector<XlaOp>> LoopBodyFn(XlaOp i, absl::Span<const XlaOp> values,
XlaBuilder* builder) {
XlaOp input_array = values[0];
XlaOp upper_bound_exclusive = values[1];

XlaOp target_index = xla::RngUniform(
i, upper_bound_exclusive,
ShapeUtil::MakeShape(xla::PrimitiveType::S64, /*dimensions=*/{1}));

XlaOp swapped_array = Swap(input_array, i, target_index);
return std::vector<XlaOp>{swapped_array, upper_bound_exclusive};
}

} // namespace

RandPerm::RandPerm(int64_t n, const at::ScalarType dtype,
const at::Layout layout, const at::Device device,
bool pin_memory)
: XlaNode(torch::lazy::OpKind(at::aten::randperm), /*operands=*/{},
[&]() { return NodeOutputShape(n); }, /*num_outputs=*/1,
torch::lazy::MHash(n)),
n_(n) {}

// Fischer Yates Shuffle.
XlaOpVector RandPerm::Lower(LoweringContext* lotcx) const {
xla::XlaBuilder* builder = lotcx->builder();
auto init_tensor = xla::Iota(lotcx->builder(), xla::PrimitiveType::S64, n_);

auto upper_bound_exclusive = xla::ConstantLiteral(
lotcx->builder(), xla::LiteralUtil::CreateR0<int64_t>(n_));
auto fischer_yates_loop = xla::ForEachIndex(
/*num_iterations=*/n_ - 1, xla::PrimitiveType::S64, &LoopBodyFn,
{init_tensor, upper_bound_exclusive}, "Fischer-Yates-Shuffle", builder);

return ReturnOp(fischer_yates_loop.value()[0], lotcx);
}

std::string RandPerm::ToString() const {
std::stringstream ss;
ss << XlaNode::ToString() << ", n=" << n_;
return ss.str();
}

} // namespace torch_xla
24 changes: 24 additions & 0 deletions torch_xla/csrc/ops/randperm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_
#define XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_

#include <vector>

#include "torch_xla/csrc/ir.h"

namespace torch_xla {

class RandPerm : public XlaNode {
public:
RandPerm(int64_t n, const at::ScalarType dtype, const at::Layout layout,
const at::Device device, bool pin_memory);

XlaOpVector Lower(LoweringContext* loctx) const override;
std::string ToString() const override;

private:
int64_t n_;
};

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_
11 changes: 11 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
#include "torch_xla/csrc/ops/put.h"
#include "torch_xla/csrc/ops/qr.h"
#include "torch_xla/csrc/ops/quant_tensor.h"
#include "torch_xla/csrc/ops/randperm.h"
#include "torch_xla/csrc/ops/recv.h"
#include "torch_xla/csrc/ops/reduce_scatter.h"
#include "torch_xla/csrc/ops/reflection_pad2d.h"
Expand Down Expand Up @@ -2261,6 +2262,16 @@ void random_(XLATensorPtr& input, int64_t from, int64_t to) {
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
}

XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
at::ScalarType scalar_type) {
// These are all PyTorch defaults. PyTorch/XLA doesn't support non default
// params here yet.
torch::lazy::NodePtr node = torch::lazy::MakeNode<RandPerm>(
n, at::ScalarType::Long, at::Layout::Strided, at::DeviceType::XLA,
/*pin_memory=*/false);
return XLATensor::Create(node, device, scalar_type);
}

XLATensorPtr reflection_pad2d(const XLATensorPtr& input,
std::vector<int64_t> padding) {
return input->CreateFrom(torch::lazy::MakeNode<ReflectionPad2d>(
Expand Down

0 comments on commit ec4b301

Please sign in to comment.