Skip to content

Commit

Permalink
Add inverse op and test
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG committed Feb 4, 2020
1 parent d9f804a commit 417ba51
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 0 deletions.
1 change: 1 addition & 0 deletions scripts/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class ArgTemplate(string.Template):
'gather_out': FuncOpts(),
'kthvalue_out': FuncOpts(),
'index_select_out': FuncOpts(),
'inverse_out' : FuncOpts(),
'log_out': FuncOpts(),
'masked_select_out': FuncOpts(),
'nonzero_out': FuncOpts(),
Expand Down
10 changes: 10 additions & 0 deletions test/cpp/test_aten_xla_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3624,6 +3624,16 @@ TEST_F(AtenXlaTensorTest, TestIndexSelectRank0) {
}
}

TEST_F(AtenXlaTensorTest, TestInverse) {
at::Tensor a = at::randn({5, 5}, at::TensorOptions(at::kFloat));
at::Tensor b = at::inverse(a);
ForEachDevice([&](const Device& device) {
at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
at::Tensor xla_b = at::inverse(xla_a);
AllClose(b, xla_b);
});
}

TEST_F(AtenXlaTensorTest, TestExpand) {
torch::Tensor a = torch::rand({3, 4}, torch::TensorOptions(torch::kFloat));
torch::Tensor b = a.expand({2, 3, 4}, /*implicit=*/false);
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,12 @@ at::Tensor AtenXlaType::index_select(const at::Tensor& self, int64_t dim,
bridge::GetXlaTensor(self), dim, bridge::GetXlaTensor(index)));
}

at::Tensor AtenXlaType::inverse(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(
XLATensor::inverse(bridge::GetXlaTensor(self)));
}

at::Tensor AtenXlaType::kl_div(const at::Tensor& self, const at::Tensor& target,
int64_t reduction) {
XLA_FN_COUNTER("xla::");
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/aten_xla_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ class AtenXlaType {
static at::Tensor index_select(const at::Tensor& self, int64_t dim,
const at::Tensor& index);

static at::Tensor inverse(const at::Tensor& self);

static at::Tensor kl_div(const at::Tensor& self, const at::Tensor& target,
int64_t reduction);

Expand Down
13 changes: 13 additions & 0 deletions torch_xla/csrc/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "torch_xla/csrc/convert_ops.h"
Expand Down Expand Up @@ -126,4 +127,16 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input,
return result;
}

xla::XlaOp BuildInverse(xla::XlaOp input) {
xla::QRDecompositionResult qr_result =
xla::QRDecomposition(input, /*full_matrices=*/false).ValueOrDie();
xla::XlaOp output =
xla::TriangularSolve(qr_result.r, xla::TransposeInMinorDims(qr_result.q),
/*left_side=*/true,
/*lower=*/false, /*unit_diagonal=*/false,
/*transpose_a=*/
xla::TriangularSolveOptions::NO_TRANSPOSE);
return output;
}

} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,6 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input,
xla::int64 offset, xla::int64 dim1,
xla::int64 dim2);

xla::XlaOp BuildInverse(xla::XlaOp input);

} // namespace torch_xla
12 changes: 12 additions & 0 deletions torch_xla/csrc/ops/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "torch_xla/csrc/elementwise.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/matrix.h"
#include "torch_xla/csrc/nll_loss.h"
#include "torch_xla/csrc/ops/arithmetic_ir_ops.h"
#include "torch_xla/csrc/ops/constant.h"
Expand Down Expand Up @@ -633,6 +634,17 @@ NodePtr LogDet(const Value& input) {
std::move(lower_fn));
}

NodePtr Inverse(const Value& input) {
auto lower_fn = [](const Node& node, LoweringContext* loctx) -> XlaOpVector {
xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
xla::XlaOp result = BuildInverse(xla_input);
return node.ReturnOp(result, loctx);
};

return GenericOp(OpKind(at::aten::inverse), {input}, input.shape(),
std::move(lower_fn));
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/ops/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ NodePtr Take(const Value& input, const Value& index);

NodePtr LogDet(const Value& input);

NodePtr Inverse(const Value& input);

} // namespace ops
} // namespace ir
} // namespace torch_xla
2 changes: 2 additions & 0 deletions torch_xla/csrc/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ class XLATensor {
static XLATensor index_select(const XLATensor& input, xla::int64 dim,
const XLATensor& index);

static XLATensor inverse(const XLATensor& input);

static XLATensor kl_div_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,10 @@ XLATensor XLATensor::index_select(const XLATensor& input, xla::int64 dim,
index_value));
}

XLATensor XLATensor::inverse(const XLATensor& input) {
return input.CreateFrom(ir::ops::Inverse(input.GetIrValue()));
}

XLATensor XLATensor::kl_div_backward(const XLATensor& grad_output,
const XLATensor& input,
const XLATensor& target,
Expand Down

0 comments on commit 417ba51

Please sign in to comment.