-
Notifications
You must be signed in to change notification settings - Fork 480
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Mason Chang <masonchang@google.com>
- Loading branch information
Showing
7 changed files
with
186 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
#include "torch_xla/csrc/ops/randperm.h" | ||
|
||
#include "torch_xla/csrc/lowering_context.h" | ||
#include "torch_xla/csrc/ops/infer_output_shape.h" | ||
#include "torch_xla/csrc/ops/xla_ops.h" | ||
#include "tsl/platform/stacktrace.h" | ||
#include "xla/client/lib/loops.h" | ||
#include "xla/shape_util.h" | ||
|
||
namespace torch_xla { | ||
namespace { | ||
|
||
using namespace xla; | ||
|
||
xla::Shape NodeOutputShape(int64_t n) { | ||
return xla::ShapeUtil::MakeShape(xla::PrimitiveType::S64, {n}); | ||
} | ||
|
||
XlaOp Swap(XlaOp input, XlaOp i, XlaOp j) { | ||
XlaOp i_value = xla::DynamicSlice(input, {i}, /*slice_sizes=*/{1}); | ||
XlaOp j_value = xla::DynamicSlice(input, {j}, /*slice_sizes=*/{1}); | ||
|
||
XlaOp write_i = xla::DynamicUpdateSlice(input, j_value, {i}); | ||
XlaOp write_j = xla::DynamicUpdateSlice(write_i, i_value, {j}); | ||
|
||
return write_j; | ||
} | ||
|
||
StatusOr<std::vector<XlaOp>> LoopBodyFn(XlaOp i, absl::Span<const XlaOp> values, | ||
XlaBuilder* builder) { | ||
XlaOp input_array = values[0]; | ||
XlaOp upper_bound_exclusive = values[1]; | ||
|
||
XlaOp target_index = xla::RngUniform( | ||
i, upper_bound_exclusive, | ||
ShapeUtil::MakeShape(xla::PrimitiveType::S64, /*dimensions=*/{1})); | ||
|
||
XlaOp swapped_array = Swap(input_array, i, target_index); | ||
return std::vector<XlaOp>{swapped_array, upper_bound_exclusive}; | ||
} | ||
|
||
} // namespace | ||
|
||
RandPerm::RandPerm(int64_t n, const at::ScalarType dtype, | ||
const at::Layout layout, const at::Device device, | ||
bool pin_memory) | ||
: XlaNode(torch::lazy::OpKind(at::aten::randperm), /*operands=*/{}, | ||
[&]() { return NodeOutputShape(n); }, /*num_outputs=*/1, | ||
torch::lazy::MHash(n)), | ||
n_(n) {} | ||
|
||
// Fischer Yates Shuffle. | ||
XlaOpVector RandPerm::Lower(LoweringContext* lotcx) const { | ||
xla::XlaBuilder* builder = lotcx->builder(); | ||
auto init_tensor = xla::Iota(lotcx->builder(), xla::PrimitiveType::S64, n_); | ||
|
||
auto upper_bound_exclusive = xla::ConstantLiteral( | ||
lotcx->builder(), xla::LiteralUtil::CreateR0<int64_t>(n_)); | ||
auto fischer_yates_loop = xla::ForEachIndex( | ||
/*num_iterations=*/n_ - 1, xla::PrimitiveType::S64, &LoopBodyFn, | ||
{init_tensor, upper_bound_exclusive}, "Fischer-Yates-Shuffle", builder); | ||
|
||
return ReturnOp(fischer_yates_loop.value()[0], lotcx); | ||
} | ||
|
||
std::string RandPerm::ToString() const { | ||
std::stringstream ss; | ||
ss << XlaNode::ToString() << ", n=" << n_; | ||
return ss.str(); | ||
} | ||
|
||
} // namespace torch_xla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ | ||
#define XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ | ||
|
||
#include <vector> | ||
|
||
#include "torch_xla/csrc/ir.h" | ||
|
||
namespace torch_xla { | ||
|
||
class RandPerm : public XlaNode { | ||
public: | ||
RandPerm(int64_t n, const at::ScalarType dtype, const at::Layout layout, | ||
const at::Device device, bool pin_memory); | ||
|
||
XlaOpVector Lower(LoweringContext* loctx) const override; | ||
std::string ToString() const override; | ||
|
||
private: | ||
int64_t n_; | ||
}; | ||
|
||
} // namespace torch_xla | ||
|
||
#endif // XLA_TORCH_XLA_CSRC_OPS_RANDPERM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters