mul: convert inputs to result type. (#7130)
ysiraichi authored Jun 3, 2024
1 parent 89832c3 commit 7938bb5
Showing 3 changed files with 220 additions and 5 deletions.
36 changes: 36 additions & 0 deletions test/test_operations.py
@@ -2082,6 +2082,42 @@ def test(f, xshape, ishapes):
for xshape, i0shape, i1shape in cases[f2]:
test(f2, xshape, (i0shape, i1shape))

  def test_inplace_mul_scalar_different_dtype(self):
    # This tests whether the returned output data-type agrees on the PyTorch
    # and XLA sides.
    #
    # Technical details: even though we were computing the common data-type
    # inside the PyTorch/XLA XLANativeFunctions::mul function, we only used it
    # to tell PyTorch what the output data-type would be (i.e. we created an
    # IR node of that data-type). Meanwhile, on the XLA side, the tensors were
    # promoted using a different set of data-type promotion rules.
    #
    # In summary, given the expressions below, the problem this test covers is:
    #
    # >>> t = torch.rand(10, dtype=torch.half)
    # >>> s = torch.tensor(5, dtype=torch.double)
    # >>> out = t.mul_(s)
    #
    # out.dtype is torch.float16, but its underlying XLA type (xla::Shape's
    # element_type) is F64.
    #
    # See: https://github.com/pytorch/xla/issues/7084

    def fn(inp, s):
      return inp.mul_(s)

    inp = torch.rand(10, dtype=torch.half)
    s = torch.tensor(7, dtype=torch.double)

    Xinp = inp.to(xm.xla_device())
    Xs = s.to(xm.xla_device())

    out = fn(inp, s)
    Xout = fn(Xinp, Xs)

    self.assertEqual(out, Xout.cpu())
    self.assertEqual("f16", torch_xla._XLAC._get_xla_tensor_shape_type(Xout))


class MNISTComparator(nn.Module):

179 changes: 174 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2,6 +2,7 @@
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/OpMathType.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
@@ -64,6 +65,171 @@
namespace torch_xla {
namespace {

using XLAInputVector = std::vector<XLATensorPtr>;

// Calls the inner function, spreading the inputs in order and appending the
// common data-type as the last argument.
template <class InnerFnType, size_t... Ints>
XLATensorPtr CallInner(const InnerFnType& inner, XLAInputVector inputs,
at::ScalarType common_dtype,
std::integer_sequence<size_t, Ints...> seq) {
return inner(inputs[Ints]..., common_dtype);
}
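// Illustrative example (not part of the original code): with two inputs and
// seq = std::make_index_sequence<2>{}, the pack expansion above reduces to
//
//   inner(inputs[0], inputs[1], common_dtype);
//
// which is how OpConfig::From (below) forwards a list of XLA tensors to a
// tensor_methods function.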

// Computes the number of XLATensorPtr arguments of a given function type.
//
// This is used when calling tensor_methods functions with a list of inputs:
// specifically, to know how many inputs we should take from that list.
template <class T>
struct NumberOfXLATensorArgs {};

template <class... Args>
struct NumberOfXLATensorArgs<XLATensorPtr(Args...)> {
static constexpr size_t value =
(std::is_same_v<XLATensorPtr,
std::remove_cv_t<std::remove_reference_t<Args>>> +
...);
};
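// Illustrative sanity check (an added example, not in the original code): for
// the tensor_methods::mul signature used further below, the fold expression
// counts the two XLATensorPtr parameters and ignores the trailing optional
// dtype:
//
//   static_assert(
//       NumberOfXLATensorArgs<XLATensorPtr(
//           const XLATensorPtr&, const XLATensorPtr&,
//           c10::optional<at::ScalarType>)>::value == 2);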

// Stateful configuration structure for pre/post-processing the inputs and the
// output of an operation.
//
// PyTorch runs a few checks and preprocessing steps that we mirror with this
// class. This should help us get the data-type behavior right in many cases.
class OpConfig {
public:
using InputVector = std::vector<at::Tensor>;
using ImplFnType =
std::function<XLATensorPtr(const XLAInputVector&, at::ScalarType)>;

// Construct an instance from a function of exactly ImplFnType.
OpConfig(ImplFnType impl) : impl_(impl) {}

// Construct an instance from a function of the following type:
// XLATensorPtr(Tensor..., ScalarType)
//
// This is a convenience for wrapping tensor_methods functions.
template <class InnerFnType>
static OpConfig From(const InnerFnType& inner_impl) {
return OpConfig(
[&](const XLAInputVector& inputs, at::ScalarType common_dtype) {
constexpr size_t num_tensor_args =
NumberOfXLATensorArgs<std::remove_pointer_t<InnerFnType>>::value;
return CallInner(inner_impl, inputs, common_dtype,
std::make_index_sequence<num_tensor_args>{});
});
}

OpConfig& add_input(const at::Tensor& input) {
inputs_.push_back(input);
return *this;
}

OpConfig& cast_inputs_to_common_dtype() {
cast_inputs_to_common_dtype_ = true;
return *this;
}

OpConfig& use_opmathtype_for_compute() {
use_opmathtype_for_compute_ = true;
return *this;
}

// Pre-processes the inputs and post-processes the outputs depending on the
// configured state of this class.
//
// In summary, it will:
// - Compute the common data-type to be used
// - Cast the inputs to the common data-type
// - Cast the inputs to their OpMathType (for the computation only)
// - Run the specified impl
// - Cast the output back to the common data-type
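//
// Worked example (illustrative, based on the test added in this commit): for
// a Half tensor multiplied in-place by a Double scalar tensor,
// at::native::result_type yields Half (the 0-dim Double tensor is in the same
// promotion category, so it does not promote the result), and
// at::toOpMathType(Half) is Float. With both flags set, the inputs are cast
// to Float, the impl is called with Float as its dtype argument, and the
// Float output is cast back to Half.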
at::Tensor run() {
at::ScalarType common_dtype = at::native::result_type(inputs_);
at::ScalarType opmathtype = at::toOpMathType(common_dtype);

// Pre-process the inputs, given the specified configuration and
// common_dtype.
InputVector inputs = maybe_preprocess_inputs(common_dtype, opmathtype);

// Look for at least one input that is already a PyTorch/XLA tensor.
InputVector::iterator it = std::find_if(
inputs.begin(), inputs.end(), [](const at::Tensor& tensor) {
return bridge::TryGetXlaTensor(tensor);
});
XLA_CHECK(it != inputs.end());
// Transform the inputs into a list of XLATensorPtr: for each input, either
// get its corresponding XLATensorPtr, or create a new one on the found XLA
// tensor's BackendDevice.
torch::lazy::BackendDevice device = bridge::GetXlaTensor(*it)->GetDevice();
XLAInputVector xla_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), xla_inputs.begin(),
[&](const at::Tensor& tensor) {
return bridge::GetOrCreateXlaTensor(tensor, device);
});

// Actually call the impl.
at::ScalarType inner_dtype =
(use_opmathtype_for_compute_) ? opmathtype : common_dtype;
XLATensorPtr xla_out = impl_(xla_inputs, inner_dtype);
at::Tensor out = bridge::AtenFromXlaTensor(xla_out);

// If we used OpMathType for the computation, cast the result back to the
// common_dtype.
if (use_opmathtype_for_compute_) {
out = out.to(common_dtype);
}

return out;
}

private:
// Pre-processes the inputs based on the state of this instance.
//
// In summary, it casts the inputs at most once:
//
// - to their OpMathType (if use_opmathtype_for_compute_ is set; this takes
//   precedence), or
//
// - to the common data-type (if cast_inputs_to_common_dtype_ is set).
InputVector maybe_preprocess_inputs(at::ScalarType common_dtype,
at::ScalarType opmathtype) {
InputVector inputs = inputs_;

// Cast only once: either to the common dtype or to OpMathType.
if (use_opmathtype_for_compute_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(opmathtype); });
} else if (cast_inputs_to_common_dtype_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(common_dtype); });
}

return inputs;
}

// Actual implementation of the operation.
ImplFnType impl_;

// List of tensor inputs.
InputVector inputs_;

// Whether to cast every input to the common data-type.
//
// This is analogous to TensorIterator's common data-type promotion flag. If
// the operation you are lowering uses TensorIterator in PyTorch, check that
// implementation to decide whether to set this flag.
bool cast_inputs_to_common_dtype_ = false;

// Whether to use OpMathType for the computation.
//
// This flag mimics the actual PyTorch kernel implementations. When lowering
// an operation, look at the corresponding PyTorch kernel to decide whether to
// set this flag.
bool use_opmathtype_for_compute_ = false;
};

at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't
// have sizes/strides
@@ -2055,11 +2221,14 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
const at::Tensor& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return DoBinaryOp(self, other,
-                    [&](const XLATensorPtr& xself, const XLATensorPtr& xother,
-                        at::ScalarType dtype) {
-                      return tensor_methods::mul(xself, xother, dtype);
-                    });
+  using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
+                              c10::optional<at::ScalarType>);
+  return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
+      .add_input(self)
+      .add_input(other)
+      .cast_inputs_to_common_dtype()
+      .use_opmathtype_for_compute()
+      .run();
}
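// Note on the lowering above (explanatory addition, not from the original
// sources): the static_cast<FnType*> pins a concrete function type. That both
// selects the tensor-tensor overload of tensor_methods::mul (a Scalar variant
// also exists) and gives NumberOfXLATensorArgs a concrete signature from
// which to deduce that two tensor inputs should be taken from the input list.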

at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
10 changes: 10 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1176,6 +1176,16 @@ void InitXlaModuleBindings(py::module m) {
[](const at::Tensor& tensor) -> std::string {
return GetXLATensorDebugInfo(tensor);
});
m.def("_get_xla_tensor_shape_type",
[](const at::Tensor& tensor) -> std::string {
          XLATensorPtr xla_tensor = bridge::TryGetXlaTensor(tensor);
          if (xla_tensor) {
            xla::Shape shape = xla_tensor->shape().get();
            return xla::primitive_util::LowercasePrimitiveTypeName(
                shape.element_type());
          }
          // The lambda returns std::string, so fall back to an empty string
          // for tensors that are not XLA tensors.
          return std::string();
        });
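  // Usage note (explanatory addition): for a Half XLA tensor this returns
  // "f16", which is what the new test_inplace_mul_scalar_different_dtype test
  // asserts via torch_xla._XLAC._get_xla_tensor_shape_type.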

py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>(
m, "XlaShardingSpec")
.def(py::init([](at::Tensor tensor, const py::list& tile_assignment,