Add data-type promotion + OpMathType usage.
ysiraichi committed May 29, 2024
1 parent dc7ef41 commit fdb724c
Showing 1 changed file with 174 additions and 5 deletions.
179 changes: 174 additions & 5 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2,6 +2,7 @@
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/MetaFunctions.h>
#include <ATen/NativeFunctions.h>
#include <ATen/OpMathType.h>
#include <ATen/Operators.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/CPUFallback.h>
@@ -64,6 +65,171 @@
namespace torch_xla {
namespace {

using XLAInputVector = std::vector<XLATensorPtr>;

// Calls the inner function, expanding the inputs in order and appending the
// common data-type as the last argument.
template <class InnerFnType, size_t... Ints>
XLATensorPtr CallInner(const InnerFnType& inner, XLAInputVector inputs,
at::ScalarType common_dtype,
std::integer_sequence<size_t, Ints...> seq) {
return inner(inputs[Ints]..., common_dtype);
}
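
// For example, with two inputs and a two-element index sequence, the pack
// expansion is equivalent to the following sketch (`fn`, `x`, and `y` are
// hypothetical names used only for illustration):
//
//   CallInner(fn, {x, y}, at::kFloat, std::make_index_sequence<2>{});
//   // expands to: fn(x, y, at::kFloat)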

// Computes the number of XLATensorPtr arguments of a given function.
//
// This is used when calling tensor_methods functions with a list of inputs.
// Specifically, it tells us how many inputs we should take from the list.
template <class T>
struct NumberOfXLATensorArgs {};

template <class... Args>
struct NumberOfXLATensorArgs<XLATensorPtr(Args...)> {
static constexpr size_t value =
(std::is_same_v<XLATensorPtr,
std::remove_cv_t<std::remove_reference_t<Args>>> +
...);
};
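
// For illustration only (the alias below is hypothetical; it mirrors the
// tensor_methods::mul signature used further down), the trait counts the two
// XLATensorPtr parameters and ignores the trailing dtype argument:
//
//   using ExampleMulFn = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
//                                     c10::optional<at::ScalarType>);
//   static_assert(NumberOfXLATensorArgs<ExampleMulFn>::value == 2,
//                 "two XLATensorPtr inputs plus one dtype argument");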

// Stateful configuration structure for pre/post-processing the inputs and the
// output.
//
// PyTorch performs a few checks and preprocessing steps that we mirror with
// this class. This should help us get the data-type behavior right in many
// cases.
class OpConfig {
public:
using InputVector = std::vector<at::Tensor>;
using ImplFnType =
std::function<XLATensorPtr(const XLAInputVector&, at::ScalarType)>;

// Construct an instance from a function of exactly ImplFnType.
OpConfig(ImplFnType impl) : impl_(impl) {}

// Construct an instance from a function of the following type:
// XLATensorPtr(XLATensorPtr..., ScalarType)
//
// This is a convenience for wrapping tensor_methods functions.
template <class InnerFnType>
static OpConfig From(const InnerFnType& inner_impl) {
return OpConfig(
[&](const XLAInputVector& inputs, at::ScalarType common_dtype) {
constexpr size_t tensor_args =
NumberOfXLATensorArgs<std::remove_pointer_t<InnerFnType>>::value;
return CallInner(inner_impl, inputs, common_dtype,
std::make_index_sequence<tensor_args>{});
});
}

OpConfig& add_input(const at::Tensor& input) {
inputs_.push_back(input);
return *this;
}

OpConfig& cast_inputs_to_common_dtype() {
cast_inputs_to_common_dtype_ = true;
return *this;
}

OpConfig& use_opmathtype_for_compute() {
use_opmathtype_for_compute_ = true;
return *this;
}

// Pre-processes the inputs and post-processes the outputs depending on the
// configured state of this class.
//
// In summary, it will:
// - Compute the common data-type to be used
// - Cast the inputs to the common data-type
// - Cast the inputs to their OpMathType (for computation only)
// - Run the specified impl
// - Cast the output back to the common data-type
at::Tensor run() {
at::ScalarType common_dtype = at::native::result_type(inputs_);
at::ScalarType opmathtype = at::toOpMathType(common_dtype);

// Pre-process the inputs, given the specified configuration and
// common_dtype.
InputVector inputs = maybe_preprocess_inputs(common_dtype, opmathtype);

// Look for at least one input tensor that is already a PyTorch/XLA tensor.
InputVector::iterator it = std::find_if(
inputs.begin(), inputs.end(), [](const at::Tensor& tensor) {
return bridge::TryGetXlaTensor(tensor);
});
XLA_CHECK(it != inputs.end());
// Transform the inputs into a list of XLATensorPtr.
// For that, either get their corresponding XLATensorPtr, or use the found
// XLA tensor's BackendDevice for creating a new one.
torch::lazy::BackendDevice device = bridge::GetXlaTensor(*it)->GetDevice();
XLAInputVector xla_inputs(inputs.size());
std::transform(inputs.begin(), inputs.end(), xla_inputs.begin(),
[&](const at::Tensor& tensor) {
return bridge::GetOrCreateXlaTensor(tensor, device);
});

// Actually call the impl.
at::ScalarType inner_dtype =
(use_opmathtype_for_compute_) ? opmathtype : common_dtype;
XLATensorPtr xla_out = impl_(xla_inputs, inner_dtype);
at::Tensor out = bridge::AtenFromXlaTensor(xla_out);

// If we used OpMathType for the computation, cast the result back to its
// common_dtype.
if (use_opmathtype_for_compute_) {
out = out.to(common_dtype);
}

return out;
}

private:
// Pre-processes the inputs based on the state of this instance.
//
// In summary:
// - Cast the inputs to the common data-type (if
// cast_inputs_to_common_dtype_ is set)
//
// - Cast the inputs to the OpMathType data-type (if
// use_opmathtype_for_compute_ is set)
InputVector maybe_preprocess_inputs(at::ScalarType common_dtype,
at::ScalarType opmathtype) {
InputVector inputs = inputs_;

// Cast only once: either to the common dtype or to OpMathType.
if (use_opmathtype_for_compute_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(opmathtype); });
} else if (cast_inputs_to_common_dtype_) {
std::transform(
inputs.begin(), inputs.end(), inputs.begin(),
[=](const at::Tensor& tensor) { return tensor.to(common_dtype); });
}

return inputs;
}

// Actual implementation of the operation.
ImplFnType impl_;

// List of tensor inputs.
InputVector inputs_;

// Whether to cast every input to the common data-type.
// It's analogous to TensorIterator's flag of the same purpose. If the
// operation you are lowering uses TensorIterator in PyTorch, check its
// configuration to decide whether to set this flag.
bool cast_inputs_to_common_dtype_ = false;

// Whether to use OpMathType for computation.
// This flag mimics what the actual PyTorch kernels do. When lowering an
// operation, look at the corresponding PyTorch kernel implementation to
// decide whether to set this flag.
bool use_opmathtype_for_compute_ = false;
};
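
// A minimal usage sketch of OpConfig (`my_impl`, `self`, and `other` are
// hypothetical; see XLANativeFunctions::mul below for the real use):
//
//   at::Tensor out = OpConfig::From(my_impl)
//                        .add_input(self)
//                        .add_input(other)
//                        .cast_inputs_to_common_dtype()
//                        .use_opmathtype_for_compute()
//                        .run();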

at::Tensor to_meta(const at::Tensor& tensor) {
// undefined tensors can't be converted to the meta device, since they don't
// have sizes/strides
@@ -2055,11 +2221,14 @@ at::Tensor XLANativeFunctions::mse_loss_backward(const at::Tensor& grad_output,
at::Tensor XLANativeFunctions::mul(const at::Tensor& self,
const at::Tensor& other) {
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
return DoBinaryOp(self, other,
[&](const XLATensorPtr& xself, const XLATensorPtr& xother,
at::ScalarType dtype) {
return tensor_methods::mul(xself, xother, dtype);
});
using FnType = XLATensorPtr(const XLATensorPtr&, const XLATensorPtr&,
c10::optional<at::ScalarType>);
return OpConfig::From(static_cast<FnType*>(tensor_methods::mul))
.add_input(self)
.add_input(other)
.cast_inputs_to_common_dtype()
.use_opmathtype_for_compute()
.run();
}

at::Tensor XLANativeFunctions::mul(const at::Tensor& self,