Add neg bit #56058

Closed
wants to merge 48 commits into from
Changes from all commits
48 commits
098c14d
Make negative operation a view
anjali411 Apr 14, 2021
f40925c
Update on "Make negative operation a view"
anjali411 Apr 19, 2021
2eefa77
Update on "Make negative operation a view"
anjali411 Apr 19, 2021
6eee5e5
Update on "Make negative operation a view"
anjali411 Apr 22, 2021
d163ce4
Update on "Make negative operation a view"
anjali411 Apr 22, 2021
85cc2c0
Update on "Make negative operation a view"
anjali411 Apr 22, 2021
2f9c273
Update on "Make negative operation a view"
anjali411 Apr 23, 2021
061bc82
Update on "Add neg bit "
anjali411 Apr 26, 2021
f18e693
Update on "Add neg bit "
anjali411 Apr 26, 2021
9dd3e81
Update on "Add neg bit "
anjali411 Apr 27, 2021
3127413
Update on "Add neg bit "
anjali411 Apr 27, 2021
353c036
Update on "Add neg bit "
anjali411 Apr 27, 2021
60aaff0
Update on "Add neg bit "
anjali411 Apr 27, 2021
3a2ac02
Update on "Add neg bit "
anjali411 Apr 27, 2021
4ada4d0
Update on "Add neg bit "
anjali411 Apr 27, 2021
8175de6
Update on "Add neg bit "
anjali411 Apr 28, 2021
f33253b
Update on "Add neg bit "
anjali411 Apr 29, 2021
83759d3
Update on "Add neg bit "
anjali411 Apr 29, 2021
d7e2de0
Update on "Add neg bit "
anjali411 Apr 29, 2021
6dd45b1
Update on "Add neg bit "
anjali411 Apr 30, 2021
3c3e637
Update on "Add neg bit "
anjali411 Apr 30, 2021
31b663c
Update on "Add neg bit "
anjali411 Apr 30, 2021
9e23b85
Update on "Add neg bit "
anjali411 Apr 30, 2021
087f102
Update on "Add neg bit "
anjali411 May 5, 2021
2f4ca30
Update on "Add neg bit "
anjali411 May 5, 2021
fd7c834
Update on "Add neg bit "
anjali411 May 5, 2021
aa30ced
Update on "Add neg bit "
anjali411 May 6, 2021
2405543
Update on "Add neg bit "
anjali411 May 10, 2021
6d747b8
Update on "Add neg bit "
anjali411 May 11, 2021
fecd435
Update on "Add neg bit "
anjali411 May 11, 2021
ec62013
Update on "Add neg bit "
anjali411 Jun 3, 2021
1fd337d
Update on "Add neg bit "
anjali411 Jun 3, 2021
0e2e89d
Update on "Add neg bit "
anjali411 Jun 3, 2021
6347641
Update on "Add neg bit "
anjali411 Jun 4, 2021
4e1ca75
Update on "Add neg bit "
anjali411 Jun 4, 2021
6d736b8
Update on "Add neg bit "
anjali411 Jun 7, 2021
fc6b8ac
Update on "Add neg bit "
anjali411 Jun 23, 2021
27e02ba
Update on "Add neg bit "
anjali411 Jul 8, 2021
a0d9e33
Update on "Add neg bit "
anjali411 Jul 9, 2021
440e420
Update on "Add neg bit "
anjali411 Jul 9, 2021
c112d2e
Update on "Add neg bit "
anjali411 Jul 9, 2021
c9e635f
Update on "Add neg bit "
anjali411 Jul 9, 2021
4f3d1de
Update on "Add neg bit "
anjali411 Jul 9, 2021
1fb249f
Update on "Add neg bit "
anjali411 Jul 9, 2021
f839bf0
Update on "Add neg bit "
anjali411 Jul 11, 2021
2ddb736
Update on "Add neg bit "
anjali411 Jul 12, 2021
c75f42a
Update on "Add neg bit "
anjali411 Jul 12, 2021
9df93ff
Update on "Add neg bit "
anjali411 Jul 12, 2021
2 changes: 1 addition & 1 deletion aten/src/ATen/BatchingRegistrations.cpp
@@ -1177,10 +1177,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
TRIVIAL_OP(imag)
TRIVIAL_OP(real);
TRIVIAL_OP(view_as_real);
TRIVIAL_OP(_view_as_real_physical);
TRIVIAL_OP(conj);
TRIVIAL_OP(_conj);
TRIVIAL_OP(resolve_conj);
TRIVIAL_OP(resolve_neg);
m.impl("view_as_complex", view_as_complex_batching_rule);
#undef TRIVIAL

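The batching change above swaps the removed `_view_as_real_physical` registration for `resolve_neg`, the materialization op this PR introduces. A minimal usage sketch of the resolve semantics (assuming the ATen API once this stack lands; these call sites are not part of the diff):

#include <ATen/ATen.h>

int main() {
  auto x = at::randn({2, 2}, at::kComplexFloat);
  auto v = at::imag(x.conj());   // lazy view: sets the negative bit instead of copying
  auto m = at::resolve_neg(v);   // materializes the negation; the result carries no neg bit
  // v and m hold the same values, but m is a physical copy.
  return 0;
}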
125 changes: 16 additions & 109 deletions aten/src/ATen/ConjugateFallback.cpp
@@ -1,118 +1,26 @@
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>
#include <c10/util/irange.h>
#include <torch/library.h>
#include <ATen/native/MathBitsFallback.h>

namespace at {

void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Out-of-place operation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
// Materialize other inputs as in (1).
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE from an argument. Conservative approach is to assume that we always
// READ from an argument, but in out-of-place operations you can skip
// conjugating inputs on entry that never get used. In current schema we
// can't easily tell if inplace situation has happened, so don't do it.

const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;

c10::optional<bool> is_write;
for (const auto i : c10::irange(num_arguments)) {
const auto& alias_info = arguments[i].alias_info();
// Three possible states:
// 1. alias_info has no value --> out-of-place operation
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
if (alias_info.has_value()) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for conjugate fallback: ", op.schema().name(),
"Conjugate fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
struct ConjFallback : MathOpFallback {
ConjFallback() : MathOpFallback(DispatchKey::Conjugate, "conjugate") {}
bool is_bit_set(const Tensor& tensor) override {
return tensor.is_conj();
}

if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle conjugation
// correctly by propagating the Conjugate dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);
return;
void _set_bit(const Tensor& tensor, bool value) override {
return tensor._set_conj(value);
}

// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;

for (const auto i : c10::irange(num_arguments)) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// View operations were already filtered above, so only in-place/out= operations should get here.
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
if (ivalue.isTensor()) {
auto* impl = ivalue.unsafeToTensorImpl();
if (!impl->is_conj()) {
continue;
}

auto tensor = std::move(ivalue).toTensor();
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), "Conjugate Fallback does not support meta tensors.");
if (mut_arg) {
// TODO: This is a waste if the argument is write only
tensor._set_conj(false);
at::conj_physical_(tensor);
mutable_inputs.emplace_back(tensor);
} else {
tensor = at::resolve_conj(tensor);
}
(*stack)[stack_start + i] = std::move(tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
if (mut_arg) {
for(const auto j : c10::irange(tensors.size())) {
Tensor t = tensors[j];
t._set_conj(false);
at::conj_physical_(t);
mutable_inputs.emplace_back(t);
}
} else {
for(const auto j : c10::irange(tensors.size())) {
tensors[j] = at::resolve_conj(tensors[j]);
}
}
(*stack)[stack_start + i] = std::move(tensors);
}
Tensor resolve_bit(const Tensor& tensor) override {
return at::resolve_conj(tensor);
}
Tensor& math_op_(Tensor& tensor) override {
return at::conj_physical_(tensor);
}
};


op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::Conjugate), stack);

for (auto& mutable_input : mutable_inputs) {
at::conj_physical_(mutable_input);
mutable_input._set_conj(true);
}
void conjugateFallback(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
ConjFallback object;
object.fallback_impl(op, dispatch_keys, stack);
}

TORCH_LIBRARY_IMPL(_, Conjugate, m) {
@@ -142,7 +50,6 @@ TORCH_LIBRARY_IMPL(aten, Conjugate, m) {
m.impl("size.int", torch::CppFunction::makeFallthrough());
m.impl("size.Dimname", torch::CppFunction::makeFallthrough());
m.impl("is_complex", torch::CppFunction::makeFallthrough());
m.impl("_view_as_real_physical", torch::CppFunction::makeFallthrough());
m.impl("view_as_real", torch::CppFunction::makeFallthrough());
m.impl("imag", torch::CppFunction::makeFallthrough());
m.impl("real", torch::CppFunction::makeFallthrough());
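The refactor above reduces ConjugateFallback.cpp to a thin subclass of the new MathOpFallback helper. For context, a hypothetical sketch of the negative-bit counterpart this PR enables might mirror it as follows (the names is_neg, _set_neg, and DispatchKey::Negative are assumptions, not shown in this diff):

struct NegFallback : MathOpFallback {
  NegFallback() : MathOpFallback(DispatchKey::Negative, "negation") {}
  bool is_bit_set(const Tensor& tensor) override {
    return tensor.is_neg();          // assumed accessor for the neg bit
  }
  void _set_bit(const Tensor& tensor, bool value) override {
    return tensor._set_neg(value);   // assumed mutator for the neg bit
  }
  Tensor resolve_bit(const Tensor& tensor) override {
    return at::resolve_neg(tensor);  // materialization op registered above
  }
  Tensor& math_op_(Tensor& tensor) override {
    return tensor.neg_();            // physical in-place negation
  }
};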
2 changes: 1 addition & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -241,6 +241,7 @@ _(aten, conj) \
_(aten, conj_physical) \
_(aten, conj_physical_) \
_(aten, resolve_conj) \
_(aten, resolve_neg) \
_(aten, complex) \
_(aten, copysign) \
_(aten, polar) \
@@ -768,7 +769,6 @@ _(aten, zeros_like) \
_(aten, real) \
_(aten, imag) \
_(aten, view_as_real) \
_(aten, _view_as_real_physical) \
_(aten, view_as_complex) \
/* nothing */

18 changes: 9 additions & 9 deletions aten/src/ATen/native/ComplexHelper.h
@@ -31,16 +31,8 @@ inline DimVector computeStrideForViewAsReal(IntArrayRef oldstride) {
return res;
}

// expects as input a complex tensor and returns back a tensor
// with corresponding real dtype containing the complex values
// in the last two dimensions
Tensor view_as_real(const Tensor& self) {
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
return native::_view_as_real_physical(self);
}

Tensor _view_as_real_physical(const Tensor& self) {
TORCH_CHECK(self.is_complex(), "view_as_real_physical is only supported for complex tensors");
TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");
auto old_sizes = self.sizes();
DimVector new_sizes(old_sizes.size() + 1);
std::copy(old_sizes.begin(), old_sizes.end(), new_sizes.begin());
@@ -53,6 +45,14 @@ Tensor _view_as_real_physical(const Tensor& self) {
return real_tensor;
}

// expects as input a complex tensor and returns back a tensor
// with corresponding real dtype containing the complex values
// in the last two dimensions
Tensor view_as_real(const Tensor& self) {
TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
return _view_as_real_physical(self);
}

inline DimVector computeStrideForViewAsComplex(IntArrayRef oldstride) {
const int64_t dim = oldstride.size();
TORCH_CHECK(oldstride[dim-1] == 1, "Tensor must have a last dimension with stride 1");
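A minimal usage sketch of the reordered helpers above (assuming libtorch; not part of the diff): view_as_real appends a trailing dimension of size two that aliases the real and imaginary components, and per the TORCH_CHECK it refuses unresolved conjugate views.

#include <ATen/ATen.h>

int main() {
  auto z = at::randn({3, 4}, at::kComplexFloat);
  auto r = at::view_as_real(z);                           // shape {3, 4, 2}, aliases z's storage
  auto c = at::view_as_real(at::resolve_conj(z.conj()));  // conjugate views must be materialized first
  return r.size(-1) == 2 ? 0 : 1;
}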
137 changes: 137 additions & 0 deletions aten/src/ATen/native/MathBitsFallback.h
@@ -0,0 +1,137 @@
#include <ATen/ATen.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/NativeFunctions.h>

namespace at {

// This fallback should only be used for operations that are self-inverse and have a corresponding tensor
// bit (internally implemented using a DispatchKey) that maintains this state on the tensor.
// Currently there are two tensor bits that trigger this fallback: the conjugate bit and the negative bit.
// The conjugate bit is set on a tensor when `.conj()` is called, and the neg bit is set when `.conj().imag` is called.

struct MathOpFallback {
MathOpFallback(DispatchKey key_, string op_name_) : key(key_), op_name(op_name_) {}
virtual bool is_bit_set(const Tensor&) = 0;
virtual void _set_bit(const Tensor&, bool) = 0;
// materializes the bit, i.e., returns a new tensor containing the true output
// (after performing the math operation corresponding to the tensor bit) if the bit is set to 1
// else returns self.
virtual Tensor resolve_bit(const Tensor&) = 0;
// in-place operation corresponding to the math op represented by the bit. In the future, if this class
// is generalized for ops that are not self-inverse, then this must be replaced by op_inverse_inplace
virtual Tensor& math_op_(Tensor&) = 0;
void fallback_impl(const c10::OperatorHandle& op, DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
// Situations to handle:
// 1. Out-of-place operation. Easy: materialize all inputs and
// call it a day.
// 2. Inplace operation. Desugar x.add_(2) into x.conj_().add_(2).conj_().
// Materialize other inputs as in (1).
// 3. out= operation. Desugar add(x, 2, out=y) into y.copy_(add(x, 2))
// Materialize other inputs as in (1).
//
// It is important to be able to tell if we READ from an argument and if we
// WRITE from an argument. Conservative approach is to assume that we always
// READ from an argument, but in out-of-place operations you can skip
// conjugating inputs on entry that never get used. In current schema we
// can't easily tell if inplace situation has happened, so don't do it.

const auto& arguments = op.schema().arguments();
const auto num_arguments = arguments.size();
const auto stack_start = stack->size() - num_arguments;

c10::optional<bool> is_write;
for (int64_t i = 0; i < num_arguments; ++i) {
// Three possible states:
// 1. alias_info has no value --> out-of-place operation
// 2. alias_info does have a value, alias_info->is_write=True --> in-place or out= operation
// 3. alias_info does have a value, alias_info->is_write=False --> view operation
const auto& alias_info = arguments[i].alias_info();
if (alias_info.has_value()) {
if (is_write.has_value()) {
TORCH_CHECK(*is_write == alias_info->isWrite(),
"Unsupported operator for ", op_name, " fallback: ", op.schema().name(),
op_name, " fallback doesn't work for operators with a mix "
"mutable and non-mutable inputs that alias with outputs, "
"this must be implemented manually. "
"If you got this error on a core op, please report a bug to PyTorch.");
} else {
is_write = alias_info->isWrite();
}
}
}

if (is_write.has_value() && !*is_write) {
// We assume that view operators automatically handle the math bit
// correctly by propagating the dispatch key in key_set.
// This is not necessarily always right, so you should test these cases.
op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);
return;
}

// Mutable inputs to be tracked separately
std::vector<Tensor> mutable_inputs;

for (int64_t i = 0; i < num_arguments; ++i) {
auto& ivalue = (*stack)[stack_start + i];
if (!(ivalue.isTensor() || ivalue.isTensorList())) {
continue;
}
const auto& argument = arguments[i];
bool mut_arg = false;
if (argument.alias_info()) {
// Was already tested by is_write loop above
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(argument.alias_info()->isWrite());
mut_arg = true;
}
if (ivalue.isTensor()) {
if (!is_bit_set(ivalue.toTensor())) {
continue;
}

auto tensor = std::move(ivalue).toTensor();
TORCH_CHECK_NOT_IMPLEMENTED(!tensor.is_meta(), op_name, " fallback does not support meta tensors.");
if (mut_arg) {
// TODO: This is a waste if the argument is write only
_set_bit(tensor, false);
math_op_(tensor);
mutable_inputs.emplace_back(tensor);
} else {
tensor = resolve_bit(tensor);
}
(*stack)[stack_start + i] = std::move(tensor);
} else if (ivalue.isTensorList()) {
auto tensors = std::move(ivalue).toTensorList();
if (mut_arg) {
for(const auto j : c10::irange(tensors.size())) {
Tensor t = tensors[j];
_set_bit(t, false);
math_op_(t);
mutable_inputs.emplace_back(t);
}
} else {
for(const auto j : c10::irange(tensors.size())) {
tensors[j] = resolve_bit(tensors[j]);
}
}
(*stack)[stack_start + i] = std::move(tensors);
}
}

op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, key), stack);

for (auto& mutable_input : mutable_inputs) {
math_op_(mutable_input);
_set_bit(mutable_input, true);
}
}

virtual ~MathOpFallback() = default;

DispatchKey key;
string op_name;
};

} // namespace at
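For completeness: the boxed-fallback registration that ties fallback_impl to a dispatch key is truncated in the ConjugateFallback.cpp hunk above. The usual registration pattern looks roughly like the following sketch (an assumption about the truncated body, not copied from this diff):

// Inside namespace at, next to the ConjFallback definition shown earlier.
TORCH_LIBRARY_IMPL(_, Conjugate, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&conjugateFallback>());
}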