
Remove obsolete device special cases #6197

Merged: 11 commits, Jan 3, 2024
35 changes: 35 additions & 0 deletions test/pjrt/test_dtypes.py
@@ -0,0 +1,35 @@
from absl.testing import absltest, parameterized
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class TestDtypes(parameterized.TestCase):

  @parameterized.parameters(torch.float16, torch.float32, torch.float64,
                            torch.bfloat16, torch.complex64)
  def test_float_round_trip(self, dtype: torch.dtype):
    t = torch.randn((3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  @parameterized.parameters(
      torch.uint8,
      torch.int8,
      torch.int16,
      torch.int32,
      torch.int64,
  )
  def test_int_round_trip(self, dtype: torch.dtype):
    t = torch.randint(0, 128, (3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  def test_bool_round_trip(self):
    t = torch.randint(0, 2, (3, 3), dtype=torch.bool)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)


if __name__ == "__main__":
  absltest.main()
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -128,7 +128,7 @@ function run_torchrun {
echo "Running torchrun test for GPU $@"
num_devices=$(nvidia-smi --list-gpus | wc -l)
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@
fi
fi
}

function run_torch_op_tests {
@@ -190,6 +190,7 @@ function run_xla_op_tests1 {
# DO NOT MODIFY
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -57,6 +57,7 @@ spec:
python3 /src/pytorch/xla/test/test_autocast.py
python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
volumeMounts:
- mountPath: /dev/shm
48 changes: 6 additions & 42 deletions torch_xla/csrc/convert_ops.cpp
@@ -15,11 +15,6 @@
namespace torch_xla {
namespace {

xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) {
xla::XlaOp zero = xla::Zero(op.builder(), from);
return xla::Ne(op, zero);
}

xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size,
int64_t narrow_size) {
uint64_t mask_value =
@@ -53,50 +48,20 @@ xla::XlaOp ConvertData(xla::XlaOp op, xla::PrimitiveType type,
} // namespace

xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to,
const torch::lazy::BackendDevice* device) {
xla::PrimitiveType to) {
if (from == to) {
return op;
}
XlaDeviceType hw_type =
static_cast<XlaDeviceType>(bridge::GetDeviceOrCurrent(device).type());
if (hw_type != XlaDeviceType::TPU) {
return xla::ConvertElementType(op, to);
}
switch (from) {
case xla::PrimitiveType::PRED:
case xla::PrimitiveType::S8:
case xla::PrimitiveType::U8:
case xla::PrimitiveType::S16:
case xla::PrimitiveType::U16:
case xla::PrimitiveType::S32:
case xla::PrimitiveType::U32:
case xla::PrimitiveType::BF16:
case xla::PrimitiveType::F32:
return xla::ConvertElementType(op, to);
case xla::PrimitiveType::S64:
case xla::PrimitiveType::U64: {
switch (to) {
case xla::PrimitiveType::PRED:
return ExplicitBooleanConvert(op, from);
default:
return xla::ConvertElementType(op, to);
}
break;
}
default:
XLA_ERROR() << "Unsupported XLA type " << from;
}
return xla::ConvertElementType(op, to);
}

xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType raw_from, xla::PrimitiveType to,
xla::PrimitiveType raw_to,
const torch::lazy::BackendDevice* device) {
xla::PrimitiveType raw_to) {
if (from != raw_from) {
op = ConvertData(op, from, raw_from);
}
xla::XlaOp result = ConvertTo(op, from, to, device);
xla::XlaOp result = ConvertTo(op, from, to);
return to == raw_to ? result : ConvertData(result, to, raw_to);
}

@@ -105,8 +70,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
op = ConvertTo(
op, from,
MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device),
&xla_device);
MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device));
}
return op;
}
@@ -120,7 +84,7 @@ xla::XlaOp CastToScalarType(xla::XlaOp input,
if (dtype) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
MakeXlaPrimitiveType(*dtype, &xla_device), &xla_device);
MakeXlaPrimitiveType(*dtype, &xla_device));
}
return ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
}
8 changes: 3 additions & 5 deletions torch_xla/csrc/convert_ops.h
@@ -11,13 +11,11 @@
namespace torch_xla {

xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to,
const torch::lazy::BackendDevice* device);
xla::PrimitiveType to);

xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType raw_from, xla::PrimitiveType to,
xla::PrimitiveType raw_to,
const torch::lazy::BackendDevice* device);
xla::PrimitiveType raw_to);

xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from);

@@ -32,4 +30,4 @@ xla::XlaOp MaybeConvertTo(xla::XlaOp input, xla::PrimitiveType type);

} // namespace torch_xla

#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
4 changes: 2 additions & 2 deletions torch_xla/csrc/data_ops.cpp
@@ -160,7 +160,7 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
xla::XlaOp mask_pred = xla::Ne(mask, zero);
xla::XlaOp update_scalar =
ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(),
ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr);
ShapeHelper::ShapeOfXlaOp(input).element_type());
return xla::Select(mask_pred, update_scalar, input);
}

@@ -291,7 +291,7 @@ xla::XlaOp BuildUpdateSlice(xla::XlaOp input, xla::XlaOp source,
xla::XlaOp update_source = source;
if (source_shape.element_type() != input_shape.element_type()) {
update_source = ConvertTo(source, source_shape.element_type(),
input_shape.element_type(), /*device=*/nullptr);
input_shape.element_type());
}
xla::XlaOp reshaped_source =
XlaHelpers::ReshapeToRank(update_source, input_shape.rank());
26 changes: 6 additions & 20 deletions torch_xla/csrc/dtype.cpp
@@ -75,16 +75,6 @@
return use_32bit_long;
}

bool IsTpuDevice(XlaDeviceType hw_type) {
static bool spmd_device_is_tpu =
(hw_type == XlaDeviceType::SPMD) &&
// HACK: find a better way to decide if SPMD is actually a TPU without
// accessing the runtime.
runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") !=
std::string::npos;
return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu;
}

} // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
hw_type == XlaDeviceType::NEURON) {
if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
Collaborator:

@jeffhataws do you know which dtypes the Neuron device does not support?

@will-cromar in the long term I think we might want to come up with a mechanism for each backend to register which dtypes it supports and how it wants to map PyTorch types to XLA types.

Collaborator Author:

Yeah, if we keep this downcasting behavior, I will consolidate it in the DevicePlugin API (see #6242).

Collaborator:

Yeah, it is best to keep fp32, since it is currently supported; see https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/arch/neuron-features/data-types.html.

return xla::PrimitiveType::F32;
}
return xla::PrimitiveType::F64;
Expand All @@ -175,20 +164,17 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
: xla::PrimitiveType::F32;
case xla::PrimitiveType::U16:
return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
: xla::PrimitiveType::U32;
case xla::PrimitiveType::S16:
return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S64:
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
case xla::PrimitiveType::U64:
return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
case xla::PrimitiveType::C128:
return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128
: xla::PrimitiveType::C64;
return xla::PrimitiveType::C128;
default:
return type;
}
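The review thread above floats the idea of letting each backend register its own dtype support instead of hard-coding device checks in dtype.cpp. A minimal sketch of what such a registration hook could look like is below. Every name in it (DtypeRegistry, DtypeMapper, Register, Map) is hypothetical and not an existing torch_xla or XLA API; only xla::PrimitiveType and XlaDeviceType come from the current headers.

// Hypothetical sketch of a per-backend dtype registry; not part of this PR.
// Include paths are the ones used elsewhere in torch_xla/csrc; adjust if
// they differ in your checkout.
#include <functional>
#include <map>
#include <utility>

#include "torch_xla/csrc/device.h"  // XlaDeviceType
#include "xla/xla_data.pb.h"        // xla::PrimitiveType

namespace torch_xla {

// Backend-provided hook: map a requested XLA element type to the type the
// device actually supports (e.g. a Neuron mapper could turn F64 into F32 and
// S16 into S32).
using DtypeMapper = std::function<xla::PrimitiveType(xla::PrimitiveType)>;

class DtypeRegistry {
 public:
  static DtypeRegistry& Get() {
    static DtypeRegistry registry;
    return registry;
  }

  // Called once by each device plugin during initialization.
  void Register(XlaDeviceType device, DtypeMapper mapper) {
    mappers_[device] = std::move(mapper);
  }

  // Backends that never register a mapper get the identity mapping, i.e.
  // every type is assumed to be supported natively.
  xla::PrimitiveType Map(XlaDeviceType device, xla::PrimitiveType type) const {
    auto it = mappers_.find(device);
    return it != mappers_.end() ? it->second(type) : type;
  }

 private:
  std::map<XlaDeviceType, DtypeMapper> mappers_;
};

}  // namespace torch_xla

With something along these lines, a plugin such as Neuron could install its mapper once at startup, and MaybeDowncastToXlaDeviceType could collapse to a single registry lookup rather than per-device branches.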
19 changes: 9 additions & 10 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2);
xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result);
if (type1 == type2 && type1 != result_type) {
return ConvertTo(result, result_type, type1, /*device=*/nullptr);
return ConvertTo(result, result_type, type1);
}
return result;
}
@@ -489,10 +489,10 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(xla::XlaOp op1,
xla::PrimitiveType type2 = TypeOfXlaOp(op2);
xla::PrimitiveType result_type = PromoteType(type1, type2);
if (type1 != result_type) {
op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
op1 = ConvertTo(op1, type1, result_type);
}
if (type2 != result_type) {
op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
op2 = ConvertTo(op2, type2, result_type);
}
return std::pair<xla::XlaOp, xla::XlaOp>(op1, op2);
}
@@ -504,13 +504,13 @@ std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(
xla::PrimitiveType type3 = TypeOfXlaOp(op3);
xla::PrimitiveType result_type = PromoteType(type1, type2, type3);
if (type1 != result_type) {
op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
op1 = ConvertTo(op1, type1, result_type);
}
if (type2 != result_type) {
op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
op2 = ConvertTo(op2, type2, result_type);
}
if (type3 != result_type) {
op3 = ConvertTo(op3, type3, result_type, /*device=*/nullptr);
op3 = ConvertTo(op3, type3, result_type);
}
return std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp>(op1, op2, op3);
}
@@ -519,10 +519,9 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecondValue(
xla::XlaOp op1, xla::XlaOp op2) {
xla::PrimitiveType type1 = TypeOfXlaOp(op1);
xla::PrimitiveType type2 = TypeOfXlaOp(op2);
return type1 == type2
? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
: std::pair<xla::XlaOp, xla::XlaOp>(
op1, ConvertTo(op2, type2, type1, /*device=*/nullptr));
return type1 == type2 ? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
: std::pair<xla::XlaOp, xla::XlaOp>(
op1, ConvertTo(op2, type2, type1));
}

xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
2 changes: 1 addition & 1 deletion torch_xla/csrc/matrix.cpp
@@ -110,7 +110,7 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input,
xla::XlaOp diag_input = input;
if (target_shape->element_type() != input_shape.element_type()) {
diag_input = ConvertTo(input, input_shape.element_type(),
target_shape->element_type(), /*device=*/nullptr);
target_shape->element_type());
}
std::vector<int64_t> permutation;
xla::XlaOp diag_target = target;
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/cast.cpp
@@ -55,8 +55,7 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const {
stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type();
xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_;
xla::XlaOp output =
ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to,
/*device=*/nullptr);
ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to);
return ReturnOp(output, loctx);
}

8 changes: 3 additions & 5 deletions torch_xla/csrc/ops/ops.cpp
@@ -262,10 +262,8 @@ torch::lazy::NodePtr Clamp(const torch::lazy::Value& input,
xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type,
/*device=*/nullptr);
xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type,
/*device=*/nullptr);
xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type);
xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type);
return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
};
return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max},
@@ -412,7 +410,7 @@ torch::lazy::NodePtr Where(const torch::lazy::Value& condition,
xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(2));
xla::XlaOp pred_condition =
ConvertTo(xla_condition, XlaHelpers::TypeOfXlaOp(xla_condition),
xla::PrimitiveType::PRED, /*device=*/nullptr);
xla::PrimitiveType::PRED);
auto promoted_branches = XlaHelpers::ValidateShapes(xla_input, xla_other);
return node.ReturnOp(xla::Select(pred_condition, promoted_branches.first,
promoted_branches.second),
15 changes: 5 additions & 10 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -25,8 +25,7 @@ torch_xla::XlaOpVector Acos::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
/*device=*/nullptr);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Acos(xla_input), loctx);
}
@@ -739,8 +738,7 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
/*device=*/nullptr);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Sinh(xla_input), loctx);
}
@@ -769,8 +767,7 @@ torch_xla::XlaOpVector Sqrt::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
/*device=*/nullptr);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Sqrt(xla_input), loctx);
}
@@ -786,8 +783,7 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
/*device=*/nullptr);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Tan(xla_input), loctx);
}
@@ -796,8 +792,7 @@ torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
/*device=*/nullptr);
xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
}
return ReturnOp(xla::Tanh(xla_input), loctx);
}
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/prod.cpp
@@ -20,8 +20,7 @@ xla::XlaOp LowerProd(xla::XlaOp input, const std::vector<int64_t>& dimensions,
xla::XlaOp casted_input;
if (dtype) {
casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
MakeXlaPrimitiveType(*dtype, /*device=*/nullptr),
/*device=*/nullptr);
MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
} else {
casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
}