mul: convert inputs to result type. #7130

Merged · 3 commits · Jun 3, 2024
36 changes: 36 additions & 0 deletions test/test_operations.py
@@ -2069,6 +2069,42 @@ def test(f, xshape, ishapes):
for xshape, i0shape, i1shape in cases[f2]:
test(f2, xshape, (i0shape, i1shape))

  def test_inplace_mul_scalar_different_dtype(self):
    # This tests whether the returned output data-type agrees between the
    # PyTorch and XLA sides.
    #
    # Technical details: even though we were computing the common data-type
    # inside the PyTorch/XLA XLANativeFunctions::mul function, we were using
    # it only to tell PyTorch what the output data-type would be (i.e. to
    # create an IR node of that data-type). Meanwhile, on the XLA side, the
    # tensors were promoted using different data-type promotion rules.
    #
    # In summary, given the expressions below, the problem this test covers is:
    #
    # >>> t = torch.rand(10, dtype=torch.half)
    # >>> s = torch.tensor(5, dtype=torch.double)
    # >>> out = t.mul_(s)
    #
    # out.dtype is torch.float16, but its underlying XLA type (xla::Shape's
    # element_type) is F64.
    #
    # See: https://github.com/pytorch/xla/issues/7084

    def fn(inp, s):
      return inp.mul_(s)

    inp = torch.rand(10, dtype=torch.half)
    s = torch.tensor(7, dtype=torch.double)

    Xinp = inp.to(xm.xla_device())
    Xs = s.to(xm.xla_device())

    out = fn(inp, s)
    Xout = fn(Xinp, Xs)

    self.assertEqual(out, Xout.cpu())
    self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))


class MNISTComparator(nn.Module):

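A minimal standalone sketch of the eager-mode promotion rule the test above relies on (plain PyTorch, no XLA device required; not part of this change):

import torch

t = torch.rand(10, dtype=torch.half)
s = torch.tensor(5, dtype=torch.double)

# Both operands are in the floating-point category, and s is zero-dimensional,
# so it is treated like a scalar for type promotion: the result stays float16.
print(torch.result_type(t, s))  # torch.float16
print(t.mul(s).dtype)           # torch.float16

# The in-place variant must therefore also produce a float16 result, which is
# the dtype the XLA shape (element_type) has to report as well.
print(t.mul_(s).dtype)          # torch.float16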
179 changes: 174 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2,6 +2,7 @@
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/OpMathType.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
@@ -64,6 +65,171 @@
namespace torch_xla {
namespace {

using XLAInputVector = std::vector<XLATensorPtr>;

// Calls the inner function, forwarding the inputs in order and appending the
// common data-type as the last argument.
template <class InnerFnType, size_t... Ints>
XLATensorPtr CallInner(const InnerFnType& inner, XLAInputVector inputs,
at::ScalarType common_dtype,
std::integer_sequence<size_t, Ints...> seq) {
return inner(inputs[Ints]..., common_dtype);
}
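// Illustrative expansion (not part of this change): with 2 inputs and
// seq = std::make_index_sequence<2>{}, the call above is equivalent to
//
//   inner(inputs[0], inputs[1], common_dtype);
//
// i.e. each XLATensorPtr is passed positionally, followed by the dtype.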

// Computes the number of XLATensorPtr arguments of a given function type.
//
// This is used when calling tensor_methods functions from a list of inputs,
// in order to know how many inputs should be taken from that list.
template <class T>
struct NumberOfXLATensorArgs {};

template <class... Args>
struct NumberOfXLATensorArgs<XLATensorPtr(Args...)> {
static constexpr size_t value =
(std::is_same_v<XLATensorPtr,
std::remove_cv_t<std::remove_reference_t<Args>>> +
...);
};
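// Illustrative sketch (not part of this change): for the function type
//
//   XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
//                c10::optional<at::ScalarType>)
//
// the fold expression above contributes 1 for each parameter whose type,
// after stripping const and references, is XLATensorPtr, so value == 2; the
// trailing optional dtype parameter is not counted.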

// Stateful configuration structure for pre-processing the inputs and
// post-processing the output.
//
// PyTorch performs a few checks and pre-processing steps that we mirror with
// this class. This should help us get much of the data-type promotion
// behavior right.
class OpConfig {
public:
using InputVector = std::vector<at::Tensor>;
using ImplFnType =
std::function<XLATensorPtr(const XLAInputVector&, at::ScalarType)>;

// Construct an instance from a function of exactly ImplFnType.
OpConfig(ImplFnType impl) : impl_(impl) {}

// Construct an instance from a function of the following type:
// XLATensorPtr(Tensor..., ScalarType)
//
// This is a convenience for wrapping tensor_methods functions.
template <class InnerFnType>
static OpConfig From(const InnerFnType& inner_impl) {
return OpConfig(
[&](const XLAInputVector& inputs, at::ScalarType common_dtype) {
constexpr size_t num_tensor_args =
NumberOfXLATensorArgs<std::remove_pointer_t<InnerFnType>>::value;
return CallInner(inner_impl, inputs, common_dtype,
std::make_index_sequence<num_tensor_args>{});
});
}

OpConfig& add_input(const at::Tensor& input) {
inputs_.push_back(input);
return *this;
}

OpConfig& cast_inputs_to_common_dtype() {
cast_inputs_to_common_dtype_ = true;
return *this;
}

OpConfig& use_opmathtype_for_compute() {
use_opmathtype_for_compute_ = true;
return *this;
}

// Pre-processes the inputs and post-processes the output depending on the
// configured state of this class.
//
// In summary, it will:
// - Compute the common data-type to be used
// - Cast the inputs to the common data-type
// - Cast the inputs to their OpMathType (for computation only)
// - Run the specified impl
// - Cast the output back to the common data-type
at::Tensor run() {
at::ScalarType common_dtype = at::native::result_type(inputs_);
at::ScalarType opmathtype = at::toOpMathType(common_dtype);

// Pre-process the inputs, given the specified configuration and
// common_dtype.
InputVector inputs = maybe_preprocess_inputs(common_dtype, opmathtype);

// Look for at least one tensor that already lives in PyTorch/XLA.
InputVector::iterator it = std::find_if(
inputs.begin(), inputs.end(), [](const at::Tensor& tensor) {
return bridge::TryGetXlaTensor(tensor);
});
XLA_CHECK(it != inputs.end());
// Transform the inputs into a list of XLATensorPtr.
// For that, either get each input's corresponding XLATensorPtr, or use the
// found XLA tensor's BackendDevice to create a new one.
torch::lazy::BackendDevice device = bridge::GetXlaTensor(*it)->GetDevice();
XLAInputVector xla_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), xla_inputs.begin(),
[&](const at::Tensor& tensor) {
return bridge::GetOrCreateXlaTensor(tensor, device);
});

// Actually call the impl.
at::ScalarType inner_dtype =
(use_opmathtype_for_compute_) ? opmathtype : common_dtype;
XLATensorPtr xla_out = impl_(xla_inputs, inner_dtype);
at::Tensor out = bridge::AtenFromXlaTensor(xla_out);

// If we used OpMathType for the computation, cast the result back to its
// common_dtype.
if (use_opmathtype_for_compute_) {
out = out.to(common_dtype);
}

return out;
}
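// Illustrative walk-through of run() (not part of this change) for the case
// covered by the new test, i.e. mul with a half tensor and a double scalar
// tensor, with both cast_inputs_to_common_dtype_ and
// use_opmathtype_for_compute_ set:
//   common_dtype = Half, opmathtype = Float
//   -> inputs are cast to Float (the OpMathType cast takes precedence over
//      the common-dtype cast in maybe_preprocess_inputs)
//   -> impl_ runs with inner_dtype = Float
//   -> the output is cast back to Half, so both out.dtype and the XLA
//      shape's element_type end up as f16.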

private:
// Pre-processes the inputs based on the state of this instance.
//
// In summary:
// - Cast the inputs to the common data-type (if
// cast_inputs_to_common_dtype_ is set)
//
// - Cast the inputs to the OpMathType data-type (if
// use_opmathtype_for_compute_ is set)
InputVector maybe_preprocess_inputs(at::ScalarType common_dtype,
at::ScalarType opmathtype) {
InputVector inputs = inputs_;

// Cast only once: either to the common dtype or to OpMathType.
if (use_opmathtype_for_compute_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(opmathtype); });
} else if (cast_inputs_to_common_dtype_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(common_dtype); });
}

return inputs;
}

// Actual implementation of the operation.
ImplFnType impl_;

// List of tensor inputs.
InputVector inputs_;

// Whether to cast every input to the common data-type.
// It's analogous to the corresponding TensorIterator flag. If the operation
// you are lowering uses TensorIterator in PyTorch, check that configuration
// to decide whether to set this flag.
bool cast_inputs_to_common_dtype_ = false;

// Whether to use OpMathType for computation.
// This flag mimics the actual PyTorch kernel implementations. When lowering
// an operation, check the corresponding PyTorch kernel to decide whether to
// set this flag.
bool use_opmathtype_for_compute_ = false;
};

at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't
// have sizes/strides
@@ -2055,11 +2221,14 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
const at::Tensor& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return DoBinaryOp(self, other,
[&](const XLATensorPtr& xself, const XLATensorPtr& xother,
at::ScalarType dtype) {
return tensor_methods::mul(xself, xother, dtype);
});
using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
c10::optional<at::ScalarType>);
return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
.add_input(self)
.add_input(other)
.cast_inputs_to_common_dtype()
.use_opmathtype_for_compute()
.run();
}

at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1176,6 +1176,16 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor) -> std::string {
return GetXLATensorDebugInfo(tensor);
});
m.def("_get_xla_tensor_shape_type",
[](const at::Tensor& tensor) -> std::string {
XLATensorPtr xla_tensor = bridge::TryGetXlaTensor(tensor);
if (xla_tensor) {
xla::Shape shape = xla_tensor->shape().get();
return xla::primitive_util::LowercasePrimitiveTypeName(
shape.element_type());
}
// Non-XLA tensors have no XLA shape; return an empty string so the
// binding always returns a well-defined value.
return std::string();
});

py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
m, "XlaShardingSpec")
.def(py::init([](at::Tensor tensor, const py::list& tile_assignment,
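For reference, a usage sketch of the new binding from Python, mirroring what the added test asserts (assumes an XLA device is available; not part of this change):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.rand(10, dtype=torch.half).to(xm.xla_device())
s = torch.tensor(5, dtype=torch.double).to(xm.xla_device())
out = t.mul_(s)

# With the mul change in this PR, the XLA element type now matches out.dtype.
print(out.dtype)                                        # torch.float16
print(torch_xla._XLAC._get_xla_tensor_shape_type(out))  # "f16"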