Skip to content

Commit

Permalink
Code-gen LeakyRelu, LeakyReluBackward again (#4468)
Browse files Browse the repository at this point in the history
  • Loading branch information
lsy323 authored and ManfeiBai committed Jan 19, 2023
1 parent fb42c60 commit 663dcf8
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 155 deletions.
21 changes: 11 additions & 10 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1405,21 +1405,22 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::kthvalue(
bridge::AtenFromXlaTensor(std::get<1>(results)));
}

at::Tensor XLANativeFunctions::leaky_relu(const at::Tensor& self,
const at::Scalar& negative_slope) {
TORCH_LAZY_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu(
bridge::GetXlaTensor(self), negative_slope.to<double>()));
}

at::Tensor XLANativeFunctions::leaky_relu_backward(
const at::Tensor& grad_output, const at::Tensor& self,
const at::Scalar& negative_slope, bool self_is_result) {
TORCH_LAZY_FN_COUNTER("xla::");
XLA_CHECK(!self_is_result || negative_slope.to<double>() >= 0.0);
return bridge::AtenFromXlaTensor(tensor_methods::leaky_relu_backward(
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
negative_slope.to<double>()));
auto common_device = torch_xla::bridge::GetXlaDevice(self);
XLA_CHECK(common_device);
auto node_negative_slope =
torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
negative_slope, *common_device);
torch::lazy::NodePtr node = torch::lazy::MakeNode<LeakyReluBackward>(
bridge::GetXlaTensor(grad_output)->GetIrValue(),
bridge::GetXlaTensor(self)->GetIrValue(), node_negative_slope,
self_is_result);
return torch_xla::bridge::AtenFromXlaTensor(
torch_xla::XLATensor::Create(std::move(node), *common_device));
}

at::Tensor XLANativeFunctions::lerp(const at::Tensor& self,
Expand Down
12 changes: 6 additions & 6 deletions torch_xla/csrc/elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,
return xla::Select(Between(input, min_val, max_val), grad_output, zero);
}

xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope_value) {
return BuildLeakyReluBackward(input, input, negative_slope_value);
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope) {
return BuildLeakyReluBackward(input, input, negative_slope);
}

std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
Expand All @@ -188,7 +188,9 @@ std::vector<xla::XlaOp> BuildRrelu(xla::XlaOp input, const at::Scalar& lower,
noise = xla::Select(xla::Gt(input, zero), one, slope);
output = input * noise;
} else {
double negative_slope = (lower.to<double>() + upper.to<double>()) / 2;
xla::XlaOp negative_slope =
XlaHelpers::ScalarValue((lower.to<double>() + upper.to<double>()) / 2,
shape.element_type(), input.builder());
noise = xla::Broadcast(zero, shape.dimensions());
output = BuildLeakyRelu(input, negative_slope);
}
Expand All @@ -214,12 +216,10 @@ xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input,
}

xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
double negative_slope_value) {
xla::XlaOp negative_slope) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
negative_slope = MaybeConvertTo(negative_slope, input_shape.element_type());
xla::XlaOp zero = xla::Zero(input.builder(), input_shape.element_type());
xla::XlaOp negative_slope = XlaHelpers::ScalarValue(
negative_slope_value, input_shape.element_type(), input.builder());
return xla::Select(xla::Gt(input, zero), grad_output,
negative_slope * grad_output);
}
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/elementwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input,

// Computes the leaky rectified linear unit:
// LeakyReLU(x) = max(0, input) + negative_slope ∗ min(0, input).
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope);
xla::XlaOp BuildLeakyRelu(xla::XlaOp input, xla::XlaOp negative_slope);

xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input,
double negative_slope_value);
xla::XlaOp negative_slope);

// Computes the sigmoid function using Tanh
// Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5
Expand Down
30 changes: 0 additions & 30 deletions torch_xla/csrc/ops/leaky_relu.cpp

This file was deleted.

25 changes: 0 additions & 25 deletions torch_xla/csrc/ops/leaky_relu.h

This file was deleted.

36 changes: 0 additions & 36 deletions torch_xla/csrc/ops/leaky_relu_backward.cpp

This file was deleted.

26 changes: 0 additions & 26 deletions torch_xla/csrc/ops/leaky_relu_backward.h

This file was deleted.

15 changes: 15 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,21 @@ torch_xla::XlaOpVector Isnan::Lower(LoweringContext* loctx) const {
return ReturnOp(xla::IsNan(xla_input), loctx);
}

torch_xla::XlaOpVector LeakyRelu::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(1));
return ReturnOp(BuildLeakyRelu(xla_input, negative_slope), loctx);
}

torch_xla::XlaOpVector LeakyReluBackward::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
xla::XlaOp negative_slope = loctx->GetOutputOp(operand(2));
return ReturnOp(
BuildLeakyReluBackward(xla_grad_output, xla_input, negative_slope),
loctx);
}

torch_xla::XlaOpVector Logdet::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
return ReturnOp(xla::LogDet(xla_input), loctx);
Expand Down
24 changes: 24 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,30 @@ xla::Shape IsnanOutputShape(const torch::lazy::Value& input) {
return isnan_shape;
}

xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope) {
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands";
return BuildLeakyRelu(operands[0], operands[1]);
};
return InferOutputShape({GetXlaShape(input), GetXlaShape(negative_slope)},
lower_for_shape_fn);
}

xla::Shape LeakyReluBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope, bool self_is_result) {
auto lower_for_shape_fn =
[](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
XLA_CHECK_EQ(operands.size(), 3) << "Unexpected number of operands";
return BuildLeakyReluBackward(operands[0], operands[1], operands[2]);
};
return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input),
GetXlaShape(negative_slope)},
lower_for_shape_fn);
}

xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other) {
auto lower_for_shape_fn =
Expand Down
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ xla::Shape InverseOutputShape(const torch::lazy::Value& input);

xla::Shape IsnanOutputShape(const torch::lazy::Value& input);

xla::Shape LeakyReluOutputShape(const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope);

xla::Shape LeakyReluBackwardOutputShape(
const torch::lazy::Value& grad_output, const torch::lazy::Value& input,
const torch::lazy::Value& negative_slope, bool self_is_result);

xla::Shape LeScalarOutputShape(const torch::lazy::Value& self,
const torch::lazy::Value& other);

Expand Down
14 changes: 0 additions & 14 deletions torch_xla/csrc/tensor_methods.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
#include "torch_xla/csrc/ops/index_select.h"
#include "torch_xla/csrc/ops/infer_output_shape.h"
#include "torch_xla/csrc/ops/kth_value.h"
#include "torch_xla/csrc/ops/leaky_relu.h"
#include "torch_xla/csrc/ops/leaky_relu_backward.h"
#include "torch_xla/csrc/ops/linear_interpolation.h"
#include "torch_xla/csrc/ops/linspace.h"
#include "torch_xla/csrc/ops/log_softmax.h"
Expand Down Expand Up @@ -1407,18 +1405,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output,
grad_output->GetIrValue(), input->GetIrValue(), min_val, max_val));
}

XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope) {
return input->CreateFrom(
torch::lazy::MakeNode<LeakyRelu>(input->GetIrValue(), negative_slope));
}

XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
double negative_slope) {
return grad_output->CreateFrom(torch::lazy::MakeNode<LeakyReluBackward>(
grad_output->GetIrValue(), input->GetIrValue(), negative_slope));
}

XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
const XLATensorPtr& weight) {
return input->CreateFrom(
Expand Down
5 changes: 0 additions & 5 deletions torch_xla/csrc/tensor_methods.h
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,6 @@ XLATensorPtr hardtanh_backward(const XLATensorPtr& grad_output,
const at::Scalar& min_val,
const at::Scalar& max_val);

XLATensorPtr leaky_relu(const XLATensorPtr& input, double negative_slope);
XLATensorPtr leaky_relu_backward(const XLATensorPtr& grad_output,
const XLATensorPtr& input,
double negative_slope);

XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
const XLATensorPtr& weight);
XLATensorPtr lerp(const XLATensorPtr& input, const XLATensorPtr& end,
Expand Down
3 changes: 2 additions & 1 deletion xla_native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ full_codegen:
- hardswish_backward
- inverse
- isnan
- leaky_relu
- le.Scalar
- le.Tensor
- logdet
Expand Down Expand Up @@ -92,6 +93,7 @@ ir_gen:
- bitwise_and.Tensor
- bitwise_or.Tensor
- bitwise_xor.Tensor
- leaky_relu_backward
supported:
- __ilshift__.Scalar
- __ilshift__.Tensor
Expand Down Expand Up @@ -192,7 +194,6 @@ supported:
- index_select
- kl_div
- kthvalue
- leaky_relu
- leaky_relu_backward
- lerp.Scalar
- lerp.Tensor
Expand Down

0 comments on commit 663dcf8

Please sign in to comment.