Make negative operation a view
ghstack-source-id: 113de6e3ce6534aa134da5ab61cc095e72bc804f
Pull Request resolved: #56058
anjali411 committed Apr 14, 2021
1 parent 795e68a commit f9dc1a9
Showing 16 changed files with 286 additions and 39 deletions.
122 changes: 122 additions & 0 deletions aten/src/ATen/native/NegateFallback.cpp
@@ -0,0 +1,122 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>

namespace at {

void negationFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Purely functional situation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.neg_().add_(2).neg_().
// Materialize other inputs as in (1).
// 3. Out-of-place operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE to an argument. The conservative approach is to assume that we always
// READ from an argument, but for out-of-place operations you can skip
// negating inputs on entry that never get used. With the current schema we
// can't easily tell if the inplace situation has happened, so we don't do it.

// std::cerr << "neg fallback " << op.schema().name() << "\n";

const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;

c10::optional<bool> is_write;
for (int64_t i = 0; i < num_arguments; ++i) {
const auto& alias_info = arguments[i].alias_info();
if (alias_info.has_value()) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for negation fallback: ", op.schema().name(),
"Negation fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}

if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle negation
// correctly by propagating the Negative dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Negative), stack);
return;
}

// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;

for (int64_t i = 0; i < num_arguments; ++i) {
auto& ivalue = (*stack)[stack_start + i];
if (!ivalue.isTensor()) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// Was already tested by is_write loop above
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
auto* impl = ivalue.unsafeToTensorImpl();
if (!impl->is_neg()) {
continue;
}

auto tensor = std::move(ivalue).toTensor();
if (mut_arg) {
// TODO: This is a waste if the argument is write only
native::neg_(tensor);
tensor.set_neg(false);
mutable_inputs.emplace_back(tensor);
} else {
tensor = native::resolve_neg(tensor);
}
(*stack)[stack_start + i] = std::move(tensor);
}

op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Negative), stack);

for (auto& mutable_input : mutable_inputs) {
native::neg_(mutable_input);
mutable_input.set_neg(true);
}
}

TORCH_LIBRARY_IMPL(_, Negative, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&negationFallback>());
}

TORCH_LIBRARY_IMPL(aten, Negative, m) {
m.impl("copy_", torch::CppFunction::makeFallthrough());
m.impl("conj", torch::CppFunction::makeFallthrough());
m.impl("empty_like", torch::CppFunction::makeFallthrough());
m.impl("empty.out", torch::CppFunction::makeFallthrough());
m.impl("empty_strided", torch::CppFunction::makeFallthrough());
m.impl("stride.int", torch::CppFunction::makeFallthrough());
m.impl("stride.Dimname", torch::CppFunction::makeFallthrough());
m.impl("size.int", torch::CppFunction::makeFallthrough());
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
m.impl("is_complex", torch::CppFunction::makeFallthrough());
m.impl("is_floating_point", torch::CppFunction::makeFallthrough());
m.impl("view_as_real_physical", torch::CppFunction::makeFallthrough());
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
m.impl("imag", torch::CppFunction::makeFallthrough());
m.impl("real", torch::CppFunction::makeFallthrough());
m.impl("view", torch::CppFunction::makeFallthrough());
m.impl("reshape", torch::CppFunction::makeFallthrough());
m.impl("select", torch::CppFunction::makeFallthrough());
// TODO: need to hit the view functions
}

} // namespace at
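
To make the desugaring described in the fallback comment concrete, here is a small hypothetical sketch (not part of this commit) of the in-place case, assuming a libtorch build that includes this branch; the inline comments describe what negationFallback effectively does around the redispatch:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::arange(3, at::kFloat);   // [0, 1, 2]
  at::Tensor y = x.neg();                     // with this commit: a lazy view, y.is_neg() == true
  // y.add_(2) is routed to negationFallback, which conceptually performs:
  //   y.neg_(); y.set_neg(false);            // materialize the pending negation in place
  //   ... redispatch add_ below the Negative key ...
  //   y.neg_(); y.set_neg(true);             // restore the lazily-negated representation
  y.add_(2);
  std::cout << at::resolve_neg(y) << "\n";    // expected: [2, 1, 0], i.e. (-x) + 2
  // Note: y aliases x's storage, so x is updated through the view as well.
  return 0;
}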
4 changes: 4 additions & 0 deletions aten/src/ATen/native/TypeProperties.cpp
@@ -30,6 +30,10 @@ bool is_conj(const Tensor& self) {
return self.is_conj();
}

bool is_neg(const Tensor& self) {
return self.is_neg();
}

bool is_sparse(const Tensor& self) {
return self.is_sparse();
}
26 changes: 23 additions & 3 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -255,8 +255,9 @@ Tensor real(const Tensor& self) {

Tensor imag(const Tensor& self) {
if (self.is_complex()) {
auto real_tensor = at::view_as_real(self);
return at::select(real_tensor, real_tensor.dim() - 1, 1);
auto real_tensor = at::view_as_real_physical(self);
auto true_real_tensor = self.is_conj() ? real_tensor.neg() : real_tensor;
return at::select(true_real_tensor, real_tensor.dim() - 1, 1);
} else {
TORCH_CHECK(false, "imag is not implemented for tensors with non-complex dtypes.");
}
@@ -576,7 +577,26 @@ Tensor& neg_out(const Tensor& self, Tensor& result) {
"If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
return unary_op_impl_out(result, self, neg_stub);
}
Tensor neg(const Tensor& self) { return unary_op_impl(self, at::neg_out); }

Tensor resolve_neg(const Tensor& self) {
if (!self.is_neg()) { return self; }
auto result = at::empty_like(self, self.options());
// negation is handled in `copy_()`
return result.copy_(self);
}

Tensor neg(const Tensor& self) {
Tensor self_;
auto impl = c10::make_intrusive<TensorImpl>(
Storage(self.storage()), self.key_set(), self.dtype());
impl->set_storage_offset(self.storage_offset());
impl->set_sizes_and_strides(self.sizes(), self.strides());
impl->set_neg(!self.is_neg());
self_ = Tensor(std::move(impl));
namedinference::propagate_names(self_, self);
return self_;
}

Tensor& neg_(Tensor& self) { return unary_op_impl_(self, at::neg_out); }

Tensor& negative_out(const Tensor& self, Tensor& result) { return at::neg_out(result, self); }
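
As a rough illustration of the new behavior above (hypothetical usage, assuming a build of this branch): neg() no longer allocates or writes anything, it returns a view whose TensorImpl aliases the input's storage with the Negative bit flipped, and resolve_neg() is what actually materializes the negated values into fresh memory:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::ones({2, 2});
  at::Tensor y = x.neg();                               // view: same storage, Negative bit set
  std::cout << y.is_neg() << "\n";                      // 1
  std::cout << (x.data_ptr() == y.data_ptr()) << "\n";  // 1: no copy was made
  at::Tensor z = at::resolve_neg(y);                    // empty_like + copy_; negation happens in copy_
  std::cout << z.is_neg() << "\n";                      // 0
  std::cout << z << "\n";                               // all -1
  return 0;
}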
72 changes: 51 additions & 21 deletions aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -23,33 +23,62 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
// conditionals into a single jump table. We should have a
// single jump table here; might be worth just writing out the
// dispatch statement by hand instead of using AT_DISPATCH
if (dtype == ScalarType::Half) {
if (iter.tensor(0).is_neg() == iter.tensor(1).is_neg()) {
if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return a; });
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a; });
});
} else if (isComplexType(dtype)) {
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
} else if (dtype == ScalarType::ComplexHalf) {
cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
} else if (isQIntType(dtype)) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a; });
});
});
} else if (isComplexType(dtype)) {
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a; });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a.conj(); });
});
}
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return conj_impl(a); },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a.conj(); });
});
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return a; },
[=](Vec256<scalar_t> a) { return a; });
});
}
} else {
if (dtype == ScalarType::Half) {
cpu_kernel(iter, [=](at::Half a) -> at::Half { return -a; });
} else if (isComplexType(dtype)) {
if (iter.tensor(0).is_conj() == iter.tensor(1).is_conj()) {
AT_DISPATCH_COMPLEX_TYPES(dtype, "copy_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -a; },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a.neg(); });
});
} else {
AT_DISPATCH_COMPLEX_TYPES(dtype, "conj_kernel", [&] {
cpu_kernel_vec(
iter,
[=](scalar_t a) -> scalar_t { return -1 * conj_impl(a); },
[=](Vec256<scalar_t> a) -> Vec256<scalar_t> { return a.neg().conj(); });
});
}
} else {
AT_DISPATCH_ALL_TYPES_AND2(
ScalarType::Bool, ScalarType::BFloat16,dtype, "copy_kernel", [&] {
@@ -59,6 +88,7 @@ static void copy_kernel(TensorIterator& iter, bool non_blocking) {
[=](Vec256<scalar_t> a) { return a; });
});
}
}
} else {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
using dest_t = scalar_t;
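
A quick hedged sketch of what the new branch above does (hypothetical usage, assuming this commit): when the destination and source disagree on the negative bit (and, for complex dtypes, on the conjugate bit), copy_ applies the negation/conjugation element-wise while copying, so the destination ends up holding plain, materialized values:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor x = at::arange(3, at::kFloat) + 1;  // [1, 2, 3]
  at::Tensor y = x.neg();                        // lazy: Negative bit set, same storage as x
  at::Tensor out = at::empty_like(y);            // freshly allocated, Negative bit clear
  out.copy_(y);                                  // bits differ -> copy_kernel negates on the fly
  std::cout << out << "\n";                      // [-1, -2, -3], and out.is_neg() == false
  return 0;
}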
1 change: 1 addition & 0 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -257,6 +257,7 @@ static void reciprocal_kernel(TensorIterator& iter) {
});
}

// NB: Ignores the negative bit on tensors
static void neg_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "neg_cpu", [&]() {
cpu_kernel_vec(
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/UnarySignKernels.cu
@@ -21,6 +21,7 @@ void logical_not_kernel_cuda(TensorIterator& iter) {
});
}

// NB: Ignores the negative bit on tensors
void neg_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "neg_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -303,6 +303,9 @@
- func: resolve_conj(Tensor(a) self) -> Tensor(a)
variants: function, method

- func: resolve_neg(Tensor(a) self) -> Tensor(a)
variants: function, method

- func: acos(Tensor self) -> Tensor
variants: function, method
dispatch:
@@ -1976,6 +1979,11 @@
device_guard: False
manual_cpp_binding: True

- func: is_neg(Tensor self) -> bool
variants: function, method
device_guard: False
manual_cpp_binding: True

- func: isreal(Tensor self) -> Tensor
variants: function, method
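
For reference, the two entries added in this file (resolve_neg and is_neg) are each exposed as a free function and as a Tensor method (hypothetical usage against a build of this branch):

#include <ATen/ATen.h>

void example(const at::Tensor& t) {
  bool a = at::is_neg(t);            // function variant (manual C++ binding in Functions.h)
  bool b = t.is_neg();               // method variant (manual C++ binding in TensorBody.h)
  at::Tensor c = at::resolve_neg(t); // function variant
  at::Tensor d = t.resolve_neg();    // method variant
  (void)a; (void)b; (void)c; (void)d;
}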

4 changes: 4 additions & 0 deletions aten/src/ATen/templates/Functions.h
@@ -218,4 +218,8 @@ inline bool is_conj(const Tensor& tensor) {
return tensor.is_conj();
}

inline bool is_neg(const Tensor& tensor) {
return tensor.is_neg();
}

}
8 changes: 8 additions & 0 deletions aten/src/ATen/templates/TensorBody.h
@@ -368,6 +368,14 @@ class TORCH_API Tensor {
impl_->set_conj(conjugate);
}

inline bool is_neg() const {
return impl_->is_neg();
}

inline void set_neg(bool negative) const {
impl_->set_neg(negative);
}

/// Returns a `Tensor`'s layout.
Layout layout() const noexcept {
return impl_->layout();
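
A small hedged sketch of what the new accessors above do (again assuming this branch): set_neg only toggles the Negative dispatch-key bit on the underlying TensorImpl; the stored values are untouched until something like copy_ or resolve_neg materializes them:

#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::ones({2});
  t.set_neg(true);                              // flips only the key-set bit; no data is written
  std::cout << t.is_neg() << "\n";              // 1
  std::cout << t.data_ptr<float>()[0] << "\n";  // still 1.0 in memory
  std::cout << at::resolve_neg(t) << "\n";      // materializes the pending negation: all -1
  return 0;
}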
6 changes: 4 additions & 2 deletions c10/core/DispatchKey.cpp
@@ -69,10 +69,12 @@ const char* toString(DispatchKey t) {
case DispatchKey::Meta:
return "Meta";

case DispatchKey::InplaceOrView:
return "InplaceOrView";
case DispatchKey::Negative:
return "Negative";
case DispatchKey::Conjugate:
return "Conjugate";
case DispatchKey::InplaceOrView:
return "InplaceOrView";
case DispatchKey::Autograd:
return "Autograd";
case DispatchKey::AutogradCPU:
12 changes: 8 additions & 4 deletions c10/core/DispatchKey.h
@@ -145,6 +145,14 @@ enum class DispatchKey : uint8_t {
// constituent parts.
Named,

// The Negative dispatch key is set for any tensors that need to perform negation
// This is implemented at a dispatch level right before any backends run
Negative,

// The Conjugate dispatch key is set for any tensors that need to perform conjugation
// This is implemented at a dispatch level right before any backends run
Conjugate,

// Note [InplaceOrView key]
// InplaceOrView key is used by inplace or view ops to register a kernel
// that does additional setup for future autograd computation.
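
To illustrate the ordering above with a hedged, standalone sketch (assuming this branch): Negative has higher dispatch priority than the backend keys, so a tensor carrying it is routed to negationFallback first, and the fallback then masks the key set with FULL_AFTER so that the redispatch lands on the backend kernel:

#include <c10/core/DispatchKeySet.h>
#include <iostream>

int main() {
  using namespace c10;
  DispatchKeySet ks = DispatchKeySet(DispatchKey::CPU).add(DispatchKey::Negative);
  std::cout << toString(ks.highestPriorityTypeId()) << "\n";     // "Negative": the fallback is hit first
  // This is the mask negationFallback applies before redispatching:
  DispatchKeySet below = ks & DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Negative);
  std::cout << toString(below.highestPriorityTypeId()) << "\n";  // "CPU": the backend kernel runs next
  return 0;
}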
@@ -181,10 +189,6 @@
// to view/inplace ops to minimize its perf impact to real models.
InplaceOrView,

// The Conjugate dispatch key is set for any tensors that need to perform conjugation
// This is implemented at a dispatch level right before any backends run
Conjugate,

// Note [Alias Dispatch Key : Autograd]
// All backends are oblivious to autograd; autograd is handled as a
// layer which happens on top of all backends. It inspects the autograd
18 changes: 18 additions & 0 deletions c10/core/TensorImpl.h
@@ -724,6 +724,24 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
}
}

/**
* Whether or not the tensor should be negated
*/
inline bool is_neg() const {
return key_set_.has(DispatchKey::Negative);
}

/**
* Set whether or not to negate the tensor (flip the negative bit).
*/
void set_neg(bool value) {
if (value) {
key_set_ = key_set_.add(DispatchKey::Negative);
} else {
key_set_ = key_set_.remove(DispatchKey::Negative);
}
}

/**
* Return the accumulated gradient of a tensor. This gradient is computed
* using forward mode AD.