From 4513eb3d4e8fe713abf361b61c12f6a5889e19b2 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 18:41:42 +0000 Subject: [PATCH 01/10] Remove obsolete device special cases --- torch_xla/csrc/convert_ops.cpp | 62 +++++++++++++++++----------------- torch_xla/csrc/dtype.cpp | 28 +++++++-------- 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index a920bdb69e9..e841cc8129c 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -15,10 +15,10 @@ namespace torch_xla { namespace { -xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) { - xla::XlaOp zero = xla::Zero(op.builder(), from); - return xla::Ne(op, zero); -} +// xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) { +// xla::XlaOp zero = xla::Zero(op.builder(), from); +// return xla::Ne(op, zero); +// } xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size, int64_t narrow_size) { @@ -60,33 +60,33 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, } XlaDeviceType hw_type = static_cast(bridge::GetDeviceOrCurrent(device).type()); - if (hw_type != XlaDeviceType::TPU) { - return xla::ConvertElementType(op, to); - } - switch (from) { - case xla::PrimitiveType::PRED: - case xla::PrimitiveType::S8: - case xla::PrimitiveType::U8: - case xla::PrimitiveType::S16: - case xla::PrimitiveType::U16: - case xla::PrimitiveType::S32: - case xla::PrimitiveType::U32: - case xla::PrimitiveType::BF16: - case xla::PrimitiveType::F32: - return xla::ConvertElementType(op, to); - case xla::PrimitiveType::S64: - case xla::PrimitiveType::U64: { - switch (to) { - case xla::PrimitiveType::PRED: - return ExplicitBooleanConvert(op, from); - default: - return xla::ConvertElementType(op, to); - } - break; - } - default: - XLA_ERROR() << "Unsupported XLA type " << from; - } + // if (hw_type != XlaDeviceType::TPU) { + return xla::ConvertElementType(op, to); + // } + // switch (from) { + // case xla::PrimitiveType::PRED: + // case xla::PrimitiveType::S8: + // case xla::PrimitiveType::U8: + // case xla::PrimitiveType::S16: + // case xla::PrimitiveType::U16: + // case xla::PrimitiveType::S32: + // case xla::PrimitiveType::U32: + // case xla::PrimitiveType::BF16: + // case xla::PrimitiveType::F32: + // return xla::ConvertElementType(op, to); + // case xla::PrimitiveType::S64: + // case xla::PrimitiveType::U64: { + // switch (to) { + // case xla::PrimitiveType::PRED: + // return ExplicitBooleanConvert(op, from); + // default: + // return xla::ConvertElementType(op, to); + // } + // break; + // } + // default: + // XLA_ERROR() << "Unsupported XLA type " << from; + // } } xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 103630bcec8..7f2d73401f5 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -75,15 +75,15 @@ bool Use32BitLong() { return use_32bit_long; } -bool IsTpuDevice(XlaDeviceType hw_type) { - static bool spmd_device_is_tpu = - (hw_type == XlaDeviceType::SPMD) && - // HACK: find a better way to decide if SPMD is actually a TPU without - // accessing the runtime. 
- runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") != - std::string::npos; - return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; -} +// bool IsTpuDevice(XlaDeviceType hw_type) { +// static bool spmd_device_is_tpu = +// (hw_type == XlaDeviceType::SPMD) && +// // HACK: find a better way to decide if SPMD is actually a TPU without +// // accessing the runtime. +// runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") != +// std::string::npos; +// return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; +// } } // namespace @@ -163,8 +163,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( if (UseBF16()) { return xla::PrimitiveType::BF16; } - if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) || - hw_type == XlaDeviceType::NEURON) { + if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) { return xla::PrimitiveType::F32; } return xla::PrimitiveType::F64; @@ -175,11 +174,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; case xla::PrimitiveType::U16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16 : xla::PrimitiveType::U32; case xla::PrimitiveType::S16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16 : xla::PrimitiveType::S32; case xla::PrimitiveType::S64: @@ -187,8 +186,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( case xla::PrimitiveType::U64: return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; case xla::PrimitiveType::C128: - return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128 - : xla::PrimitiveType::C64; + return xla::PrimitiveType::C128; default: return type; } From c438e3ad690df80e46dcc5494e7625463cb29881 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 21:07:20 +0000 Subject: [PATCH 02/10] remove comments --- torch_xla/csrc/convert_ops.cpp | 33 --------------------------------- torch_xla/csrc/dtype.cpp | 10 ---------- 2 files changed, 43 deletions(-) diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index e841cc8129c..af381c36a50 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -15,11 +15,6 @@ namespace torch_xla { namespace { -// xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) { -// xla::XlaOp zero = xla::Zero(op.builder(), from); -// return xla::Ne(op, zero); -// } - xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size, int64_t narrow_size) { uint64_t mask_value = @@ -58,35 +53,7 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, if (from == to) { return op; } - XlaDeviceType hw_type = - static_cast(bridge::GetDeviceOrCurrent(device).type()); - // if (hw_type != XlaDeviceType::TPU) { return xla::ConvertElementType(op, to); - // } - // switch (from) { - // case xla::PrimitiveType::PRED: - // case xla::PrimitiveType::S8: - // case xla::PrimitiveType::U8: - // case xla::PrimitiveType::S16: - // case xla::PrimitiveType::U16: - // case xla::PrimitiveType::S32: - // case xla::PrimitiveType::U32: - // case xla::PrimitiveType::BF16: - // case xla::PrimitiveType::F32: - // return xla::ConvertElementType(op, to); - // case xla::PrimitiveType::S64: - // case xla::PrimitiveType::U64: { - // switch (to) { - // case xla::PrimitiveType::PRED: - // return ExplicitBooleanConvert(op, from); - // 
default: - // return xla::ConvertElementType(op, to); - // } - // break; - // } - // default: - // XLA_ERROR() << "Unsupported XLA type " << from; - // } } xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 7f2d73401f5..01510ce835d 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -75,16 +75,6 @@ bool Use32BitLong() { return use_32bit_long; } -// bool IsTpuDevice(XlaDeviceType hw_type) { -// static bool spmd_device_is_tpu = -// (hw_type == XlaDeviceType::SPMD) && -// // HACK: find a better way to decide if SPMD is actually a TPU without -// // accessing the runtime. -// runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") != -// std::string::npos; -// return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; -// } - } // namespace at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { From 825686bd91f2478aee4321b2d967b23eabbe28ec Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 21:16:11 +0000 Subject: [PATCH 03/10] remove device from ConvertTo --- torch_xla/csrc/convert_ops.cpp | 10 ++++------ torch_xla/csrc/convert_ops.h | 5 ++--- torch_xla/csrc/data_ops.cpp | 4 ++-- torch_xla/csrc/helpers.cpp | 14 +++++++------- torch_xla/csrc/matrix.cpp | 2 +- torch_xla/csrc/ops/ops.cpp | 8 +++----- torch_xla/csrc/ops/ops_lower_fn.cpp | 6 ++---- torch_xla/csrc/ops/prod.cpp | 3 +-- torch_xla/csrc/xla_lower_util.cpp | 16 ++++++---------- 9 files changed, 28 insertions(+), 40 deletions(-) diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index af381c36a50..3dda62ae004 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -48,8 +48,7 @@ xla::XlaOp ConvertData(xla::XlaOp op, xla::PrimitiveType type, } // namespace xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, - xla::PrimitiveType to, - const torch::lazy::BackendDevice* device) { + xla::PrimitiveType to) { if (from == to) { return op; } @@ -63,7 +62,7 @@ xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, if (from != raw_from) { op = ConvertData(op, from, raw_from); } - xla::XlaOp result = ConvertTo(op, from, to, device); + xla::XlaOp result = ConvertTo(op, from, to); return to == raw_to ? 
result : ConvertData(result, to, raw_to); } @@ -72,8 +71,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) { torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); op = ConvertTo( op, from, - MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device), - &xla_device); + MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device)); } return op; } @@ -87,7 +85,7 @@ xla::XlaOp CastToScalarType(xla::XlaOp input, if (dtype) { torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), - MakeXlaPrimitiveType(*dtype, &xla_device), &xla_device); + MakeXlaPrimitiveType(*dtype, &xla_device)); } return ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input)); } diff --git a/torch_xla/csrc/convert_ops.h b/torch_xla/csrc/convert_ops.h index 3dd0ce99f3a..e9f8b6682d2 100644 --- a/torch_xla/csrc/convert_ops.h +++ b/torch_xla/csrc/convert_ops.h @@ -11,8 +11,7 @@ namespace torch_xla { xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, - xla::PrimitiveType to, - const torch::lazy::BackendDevice* device); + xla::PrimitiveType to); xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType raw_from, xla::PrimitiveType to, @@ -32,4 +31,4 @@ xla::XlaOp MaybeConvertTo(xla::XlaOp input, xla::PrimitiveType type); } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_ diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index bb02ca7da2e..3eb04baa8e8 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -160,7 +160,7 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp mask_pred = xla::Ne(mask, zero); xla::XlaOp update_scalar = ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(), - ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr); + ShapeHelper::ShapeOfXlaOp(input).element_type()); return xla::Select(mask_pred, update_scalar, input); } @@ -291,7 +291,7 @@ xla::XlaOp BuildUpdateSlice(xla::XlaOp input, xla::XlaOp source, xla::XlaOp update_source = source; if (source_shape.element_type() != input_shape.element_type()) { update_source = ConvertTo(source, source_shape.element_type(), - input_shape.element_type(), /*device=*/nullptr); + input_shape.element_type()); } xla::XlaOp reshaped_source = XlaHelpers::ReshapeToRank(update_source, input_shape.rank()); diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 1e4f3be9f78..af04e706a74 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -29,7 +29,7 @@ xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2); xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result); if (type1 == type2 && type1 != result_type) { - return ConvertTo(result, result_type, type1, /*device=*/nullptr); + return ConvertTo(result, result_type, type1); } return result; } @@ -489,10 +489,10 @@ std::pair XlaHelpers::PromoteValues(xla::XlaOp op1, xla::PrimitiveType type2 = TypeOfXlaOp(op2); xla::PrimitiveType result_type = PromoteType(type1, type2); if (type1 != result_type) { - op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr); + op1 = ConvertTo(op1, type1, result_type); } if (type2 != result_type) { - op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr); + op2 = ConvertTo(op2, type2, result_type); } return std::pair(op1, op2); } @@ -504,13 +504,13 @@ 
std::tuple XlaHelpers::PromoteValues( xla::PrimitiveType type3 = TypeOfXlaOp(op3); xla::PrimitiveType result_type = PromoteType(type1, type2, type3); if (type1 != result_type) { - op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr); + op1 = ConvertTo(op1, type1, result_type); } if (type2 != result_type) { - op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr); + op2 = ConvertTo(op2, type2, result_type); } if (type3 != result_type) { - op3 = ConvertTo(op3, type3, result_type, /*device=*/nullptr); + op3 = ConvertTo(op3, type3, result_type); } return std::tuple(op1, op2, op3); } @@ -522,7 +522,7 @@ std::pair XlaHelpers::PromoteSecondValue( return type1 == type2 ? std::pair(op1, op2) : std::pair( - op1, ConvertTo(op2, type2, type1, /*device=*/nullptr)); + op1, ConvertTo(op2, type2, type1)); } xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1, diff --git a/torch_xla/csrc/matrix.cpp b/torch_xla/csrc/matrix.cpp index 0cffa5c8d2a..eccfc759a3d 100644 --- a/torch_xla/csrc/matrix.cpp +++ b/torch_xla/csrc/matrix.cpp @@ -110,7 +110,7 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, xla::XlaOp diag_input = input; if (target_shape->element_type() != input_shape.element_type()) { diag_input = ConvertTo(input, input_shape.element_type(), - target_shape->element_type(), /*device=*/nullptr); + target_shape->element_type()); } std::vector permutation; xla::XlaOp diag_target = target; diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 494f79c2b09..0e1e885dde4 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -262,10 +262,8 @@ torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1)); xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2)); xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type, - /*device=*/nullptr); - xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type, - /*device=*/nullptr); + xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type); + xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type); return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx); }; return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max}, @@ -412,7 +410,7 @@ torch::lazy::NodePtr Where(const torch::lazy::Value& condition, xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(2)); xla::XlaOp pred_condition = ConvertTo(xla_condition, XlaHelpers::TypeOfXlaOp(xla_condition), - xla::PrimitiveType::PRED, /*device=*/nullptr); + xla::PrimitiveType::PRED); auto promoted_branches = XlaHelpers::ValidateShapes(xla_input, xla_other); return node.ReturnOp(xla::Select(pred_condition, promoted_branches.first, promoted_branches.second), diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 99250adbdad..54bd11a553f 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -743,8 +743,7 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return 
ReturnOp(xla::Tan(xla_input), loctx); } @@ -753,8 +752,7 @@ torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Tanh(xla_input), loctx); } diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index e61bc9541ea..1790d477cce 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -20,8 +20,7 @@ xla::XlaOp LowerProd(xla::XlaOp input, const std::vector& dimensions, xla::XlaOp casted_input; if (dtype) { casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), - MakeXlaPrimitiveType(*dtype, /*device=*/nullptr), - /*device=*/nullptr); + MakeXlaPrimitiveType(*dtype, /*device=*/nullptr)); } else { casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input)); } diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 34722987954..29230546844 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -170,7 +170,7 @@ xla::XlaOp CreateIndexAlongDim( xla::XlaOp updates = value; if (buffer_shape.element_type() != value_shape.element_type()) { updates = ConvertTo(updates, value_shape.element_type(), - buffer_shape.element_type(), /*device=*/nullptr); + buffer_shape.element_type()); } if (broadcast_value_to_index) { const xla::Shape& index_shape = ShapeHelper::ShapeOfXlaOp(index); @@ -603,7 +603,7 @@ xla::XlaOp CreateIndexUpdate( xla::XlaOp new_values = values; if (buffer_shape.element_type() != values_shape.element_type()) { new_values = ConvertTo(new_values, values_shape.element_type(), - buffer_shape.element_type(), /*device=*/nullptr); + buffer_shape.element_type()); } new_values = BuildExpand(new_values, expected_values_dims); const xla::Shape& new_values_shape = ShapeHelper::ShapeOfXlaOp(new_values); @@ -654,8 +654,7 @@ XlaOpCombiner NumericAddCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = numeric_x + numeric_y; return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -665,8 +664,7 @@ XlaOpCombiner NumericMulCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = numeric_x * numeric_y; return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -677,8 +675,7 @@ XlaOpCombiner NumericMinCombiner() { xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y); // xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y); return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -688,8 +685,7 @@ XlaOpCombiner NumericMaxCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = xla::Max(numeric_x, numeric_y); return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } From 95cf14cabf3dbd3a93fd828999b57fc6e27907d2 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 21:17:23 +0000 Subject: [PATCH 04/10] 
remove device from `ConvertToRaw` --- torch_xla/csrc/convert_ops.cpp | 3 +-- torch_xla/csrc/convert_ops.h | 3 +-- torch_xla/csrc/ops/cast.cpp | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index 3dda62ae004..cd86e0f3169 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -57,8 +57,7 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType raw_from, xla::PrimitiveType to, - xla::PrimitiveType raw_to, - const torch::lazy::BackendDevice* device) { + xla::PrimitiveType raw_to) { if (from != raw_from) { op = ConvertData(op, from, raw_from); } diff --git a/torch_xla/csrc/convert_ops.h b/torch_xla/csrc/convert_ops.h index e9f8b6682d2..029599667bd 100644 --- a/torch_xla/csrc/convert_ops.h +++ b/torch_xla/csrc/convert_ops.h @@ -15,8 +15,7 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType raw_from, xla::PrimitiveType to, - xla::PrimitiveType raw_to, - const torch::lazy::BackendDevice* device); + xla::PrimitiveType raw_to); xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from); diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index 95068640a27..f1a0a1a9072 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -55,8 +55,7 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const { stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type(); xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_; xla::XlaOp output = - ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to, - /*device=*/nullptr); + ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to); return ReturnOp(output, loctx); } From 76d825f019c6131e317cdc125d953634e3a5a735 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Mon, 18 Dec 2023 21:18:10 +0000 Subject: [PATCH 05/10] formatting --- torch_xla/csrc/dtype.cpp | 10 ++++------ torch_xla/csrc/helpers.cpp | 7 +++---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 01510ce835d..d4ed1b413a6 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -164,13 +164,11 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; case xla::PrimitiveType::U16: - return hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::U16 - : xla::PrimitiveType::U32; + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16 + : xla::PrimitiveType::U32; case xla::PrimitiveType::S16: - return hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::S16 - : xla::PrimitiveType::S32; + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16 + : xla::PrimitiveType::S32; case xla::PrimitiveType::S64: return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; case xla::PrimitiveType::U64: diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index af04e706a74..895b9f9279e 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -519,10 +519,9 @@ std::pair XlaHelpers::PromoteSecondValue( xla::XlaOp op1, xla::XlaOp op2) { xla::PrimitiveType type1 = TypeOfXlaOp(op1); xla::PrimitiveType type2 = TypeOfXlaOp(op2); - return type1 == type2 - ? 
std::pair(op1, op2) - : std::pair( - op1, ConvertTo(op2, type2, type1)); + return type1 == type2 ? std::pair(op1, op2) + : std::pair( + op1, ConvertTo(op2, type2, type1)); } xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1, From 672059462a7a7587cd66662edfb143c994211a7d Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Jan 2024 19:13:46 +0000 Subject: [PATCH 06/10] Add test_dtypes --- test/pjrt/test_dtypes.py | 43 ++++++++++++++++++++++++++++++++++++++ test/run_tests.sh | 3 ++- test/tpu/xla_test_job.yaml | 1 + 3 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 test/pjrt/test_dtypes.py diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py new file mode 100644 index 00000000000..f8d1f4955cb --- /dev/null +++ b/test/pjrt/test_dtypes.py @@ -0,0 +1,43 @@ +from absl.testing import absltest, parameterized +import torch_xla.runtime as xr +import torch_xla.core.xla_model as xm +import torch + +class TestDtypes(parameterized.TestCase): + + def setUp(self): + xr.set_device_type('TPU') + + @parameterized.parameters( + torch.float16, + torch.float32, + torch.float64, + torch.bfloat16, + torch.complex64, + torch.complex128) + def test_float_round_trip(self, dtype: torch.dtype): + t = torch.randn((3, 3), dtype=dtype) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + + @parameterized.parameters( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ) + def test_int_round_trip(self, dtype: torch.dtype): + t = torch.randint(0, 128, (3, 3), dtype=dtype) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + def test_bool_round_trip(self): + t = torch.randint(0, 2, (3, 3), dtype=torch.bool) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + +if __name__ == "__main__": + absltest.main() diff --git a/test/run_tests.sh b/test/run_tests.sh index 4e5dc6e90f6..1553d53e409 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -128,7 +128,7 @@ function run_torchrun { echo "Running torchrun test for GPU $@" num_devices=$(nvidia-smi --list-gpus | wc -l) PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@ - fi + fi } function run_torch_op_tests { @@ -190,6 +190,7 @@ function run_xla_op_tests1 { # DO NOT MODIFY function run_xla_op_tests2 { run_downcast_bf16 "$CDIR/test_data_type.py" + run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU } diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index e727953ddc4..b0d163ce6f3 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -57,6 +57,7 @@ spec: python3 /src/pytorch/xla/test/test_autocast.py python3 /src/pytorch/xla/test/dynamo/test_dynamo.py python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py + python3 /src/pytorch/xla/test/pjrt/test_dtypes.py volumeMounts: - mountPath: /dev/shm name: dshm From 92e9a2c0a360cdd7351736dc41d92d0d5c678526 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Jan 2024 19:19:57 +0000 Subject: [PATCH 07/10] formatting --- test/pjrt/test_dtypes.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index f8d1f4955cb..485f6b26b72 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -1,26 +1,21 @@ from absl.testing import absltest, parameterized -import torch_xla.runtime as xr -import torch_xla.core.xla_model as xm import torch +import torch_xla.core.xla_model as xm +import 
torch_xla.runtime as xr + class TestDtypes(parameterized.TestCase): def setUp(self): xr.set_device_type('TPU') - @parameterized.parameters( - torch.float16, - torch.float32, - torch.float64, - torch.bfloat16, - torch.complex64, - torch.complex128) + @parameterized.parameters(torch.float16, torch.float32, torch.float64, + torch.bfloat16, torch.complex64, torch.complex128) def test_float_round_trip(self, dtype: torch.dtype): t = torch.randn((3, 3), dtype=dtype) xt = t.to(xm.xla_device()) torch.testing.assert_close(xt.cpu(), t) - @parameterized.parameters( torch.uint8, torch.int8, From 5421d64d68feae7d3d0900ca648e546051082528 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Jan 2024 19:32:38 +0000 Subject: [PATCH 08/10] remove complex128 test --- test/pjrt/test_dtypes.py | 2 +- torch_xla/csrc/ops/ops_lower_fn.cpp | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index 485f6b26b72..14f1eab8935 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -10,7 +10,7 @@ def setUp(self): xr.set_device_type('TPU') @parameterized.parameters(torch.float16, torch.float32, torch.float64, - torch.bfloat16, torch.complex64, torch.complex128) + torch.bfloat16, torch.complex64) def test_float_round_trip(self, dtype: torch.dtype): t = torch.randn((3, 3), dtype=dtype) xt = t.to(xm.xla_device()) diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 2b9104192d7..e45c7782eb5 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -25,8 +25,7 @@ torch_xla::XlaOpVector Acos::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Acos(xla_input), loctx); } @@ -739,8 +738,7 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Sinh(xla_input), loctx); } @@ -769,8 +767,7 @@ torch_xla::XlaOpVector Sqrt::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Sqrt(xla_input), loctx); } From bb04b655744e70cbfe259cb63dd82db72fff7d59 Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 2 Jan 2024 21:28:34 +0000 Subject: [PATCH 09/10] Generalize `test_dtypes` --- test/pjrt/test_dtypes.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index 14f1eab8935..4279d65a137 100644 --- a/test/pjrt/test_dtypes.py +++ 
b/test/pjrt/test_dtypes.py @@ -3,15 +3,20 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr +unsupported_dtypes_per_device = { + 'TPU': [torch.complex128,], +} -class TestDtypes(parameterized.TestCase): - def setUp(self): - xr.set_device_type('TPU') +class TestDtypes(parameterized.TestCase): @parameterized.parameters(torch.float16, torch.float32, torch.float64, - torch.bfloat16, torch.complex64) + torch.bfloat16, torch.complex64, torch.complex128) def test_float_round_trip(self, dtype: torch.dtype): + unsupported_dtypes = unsupported_dtypes_per_device.get(xr.device_type(), []) + if dtype in unsupported_dtypes: + self.skipTest(f'Unsupported dtype: {dtype}') + t = torch.randn((3, 3), dtype=dtype) xt = t.to(xm.xla_device()) torch.testing.assert_close(xt.cpu(), t) From 37fff0f60c94879a61b5abf0d67170627fc450cd Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Wed, 3 Jan 2024 00:25:41 +0000 Subject: [PATCH 10/10] remove complex128 test again --- test/pjrt/test_dtypes.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py index 4279d65a137..ebac882efdf 100644 --- a/test/pjrt/test_dtypes.py +++ b/test/pjrt/test_dtypes.py @@ -3,20 +3,12 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr -unsupported_dtypes_per_device = { - 'TPU': [torch.complex128,], -} - class TestDtypes(parameterized.TestCase): @parameterized.parameters(torch.float16, torch.float32, torch.float64, - torch.bfloat16, torch.complex64, torch.complex128) + torch.bfloat16, torch.complex64) def test_float_round_trip(self, dtype: torch.dtype): - unsupported_dtypes = unsupported_dtypes_per_device.get(xr.device_type(), []) - if dtype in unsupported_dtypes: - self.skipTest(f'Unsupported dtype: {dtype}') - t = torch.randn((3, 3), dtype=dtype) xt = t.to(xm.xla_device()) torch.testing.assert_close(xt.cpu(), t)
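Verifying the series locally: besides the run_tests.sh and TPU CI hooks added in patch 06, the new dtype round-trip tests can be run directly. A minimal sketch of such an invocation (the PJRT_DEVICE value here is an example and should match the attached accelerator):

    # Assumes a torch_xla installation with an attached accelerator; adjust PJRT_DEVICE as needed.
    PJRT_DEVICE=TPU python3 test/pjrt/test_dtypes.py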