Remove obsolete device special cases (#6197)
will-cromar authored and golechwierowicz committed Jan 12, 2024
1 parent f9f84fe commit 1df07f8
Showing 14 changed files with 81 additions and 110 deletions.
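The heart of the commit is an API simplification: ConvertTo and ConvertToRaw in torch_xla/csrc/convert_ops lose their trailing torch::lazy::BackendDevice* parameter, together with the TPU-only conversion paths that parameter existed to serve. As a minimal sketch distilled from the hunks below (where - marks removed lines and + added ones; the variable names here are illustrative), call sites change like this:

  // Before: callers threaded a device through, almost always as nullptr.
  xla::XlaOp out = ConvertTo(op, from_type, to_type, /*device=*/nullptr);
  // After: the conversion depends only on the XLA element types.
  xla::XlaOp out = ConvertTo(op, from_type, to_type);

A new round-trip test, test/pjrt/test_dtypes.py, guards the change across float, integer, and bool dtypes.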
35 changes: 35 additions & 0 deletions test/pjrt/test_dtypes.py
@@ -0,0 +1,35 @@
from absl.testing import absltest, parameterized
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class TestDtypes(parameterized.TestCase):

  @parameterized.parameters(torch.float16, torch.float32, torch.float64,
                            torch.bfloat16, torch.complex64)
  def test_float_round_trip(self, dtype: torch.dtype):
    t = torch.randn((3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  @parameterized.parameters(
      torch.uint8,
      torch.int8,
      torch.int16,
      torch.int32,
      torch.int64,
  )
  def test_int_round_trip(self, dtype: torch.dtype):
    t = torch.randint(0, 128, (3, 3), dtype=dtype)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)

  def test_bool_round_trip(self):
    t = torch.randint(0, 2, (3, 3), dtype=torch.bool)
    xt = t.to(xm.xla_device())
    torch.testing.assert_close(xt.cpu(), t)


if __name__ == "__main__":
  absltest.main()
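Each test round-trips a small tensor to the XLA device and back and asserts the values survive unchanged. The file runs standalone; assuming a local build, a typical invocation would be PJRT_DEVICE=CPU python test/pjrt/test_dtypes.py. The run_tests.sh and TPU CI changes below wire it into the existing suites.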
3 changes: 2 additions & 1 deletion test/run_tests.sh
@@ -128,7 +128,7 @@ function run_torchrun {
echo "Running torchrun test for GPU $@"
num_devices=$(nvidia-smi --list-gpus | wc -l)
PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@
fi
fi
}

function run_torch_op_tests {
@@ -190,6 +190,7 @@ function run_xla_op_tests1 {
# DO NOT MODIFY
function run_xla_op_tests2 {
  run_downcast_bf16 "$CDIR/test_data_type.py"
+  run_test "$CDIR/pjrt/test_dtypes.py"
  run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

1 change: 1 addition & 0 deletions test/tpu/xla_test_job.yaml
@@ -57,6 +57,7 @@ spec:
            python3 /src/pytorch/xla/test/test_autocast.py
            python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
            python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
+            python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
            python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
        volumeMounts:
        - mountPath: /dev/shm
48 changes: 6 additions & 42 deletions torch_xla/csrc/convert_ops.cpp
@@ -15,11 +15,6 @@
namespace torch_xla {
namespace {

-xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) {
-  xla::XlaOp zero = xla::Zero(op.builder(), from);
-  return xla::Ne(op, zero);
-}
-
xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size,
                         int64_t narrow_size) {
  uint64_t mask_value =
@@ -53,50 +48,20 @@ xla::XlaOp ConvertData(xla::XlaOp op, xla::PrimitiveType type,
} // namespace

xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
-                     xla::PrimitiveType to,
-                     const torch::lazy::BackendDevice* device) {
+                     xla::PrimitiveType to) {
  if (from == to) {
    return op;
  }
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDeviceOrCurrent(device).type());
-  if (hw_type != XlaDeviceType::TPU) {
-    return xla::ConvertElementType(op, to);
-  }
-  switch (from) {
-    case xla::PrimitiveType::PRED:
-    case xla::PrimitiveType::S8:
-    case xla::PrimitiveType::U8:
-    case xla::PrimitiveType::S16:
-    case xla::PrimitiveType::U16:
-    case xla::PrimitiveType::S32:
-    case xla::PrimitiveType::U32:
-    case xla::PrimitiveType::BF16:
-    case xla::PrimitiveType::F32:
-      return xla::ConvertElementType(op, to);
-    case xla::PrimitiveType::S64:
-    case xla::PrimitiveType::U64: {
-      switch (to) {
-        case xla::PrimitiveType::PRED:
-          return ExplicitBooleanConvert(op, from);
-        default:
-          return xla::ConvertElementType(op, to);
-      }
-      break;
-    }
-    default:
-      XLA_ERROR() << "Unsupported XLA type " << from;
-  }
+  return xla::ConvertElementType(op, to);
}
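The one semantic assumption in this hunk: the deleted TPU-only path lowered S64/U64 to PRED through an explicit Ne(op, 0) (the ExplicitBooleanConvert helper removed above), whereas the simplified code trusts xla::ConvertElementType to produce the same PRED result on TPU.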

xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
                        xla::PrimitiveType raw_from, xla::PrimitiveType to,
-                        xla::PrimitiveType raw_to,
-                        const torch::lazy::BackendDevice* device) {
+                        xla::PrimitiveType raw_to) {
  if (from != raw_from) {
    op = ConvertData(op, from, raw_from);
  }
-  xla::XlaOp result = ConvertTo(op, from, to, device);
+  xla::XlaOp result = ConvertTo(op, from, to);
  return to == raw_to ? result : ConvertData(result, to, raw_to);
}

@@ -105,8 +70,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
    torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
    op = ConvertTo(
        op, from,
-        MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device),
-        &xla_device);
+        MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device));
  }
  return op;
}
@@ -120,7 +84,7 @@ xla::XlaOp CastToScalarType(xla::XlaOp input,
  if (dtype) {
    torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
    return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
-                     MakeXlaPrimitiveType(*dtype, &xla_device), &xla_device);
+                     MakeXlaPrimitiveType(*dtype, &xla_device));
  }
  return ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
}
8 changes: 3 additions & 5 deletions torch_xla/csrc/convert_ops.h
@@ -11,13 +11,11 @@
namespace torch_xla {

xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
-                     xla::PrimitiveType to,
-                     const torch::lazy::BackendDevice* device);
+                     xla::PrimitiveType to);

xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
                        xla::PrimitiveType raw_from, xla::PrimitiveType to,
-                        xla::PrimitiveType raw_to,
-                        const torch::lazy::BackendDevice* device);
+                        xla::PrimitiveType raw_to);

xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from);
@@ -32,4 +30,4 @@ xla::XlaOp MaybeConvertTo(xla::XlaOp input, xla::PrimitiveType type);

} // namespace torch_xla

-#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
\ No newline at end of file
+#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
4 changes: 2 additions & 2 deletions torch_xla/csrc/data_ops.cpp
@@ -160,7 +160,7 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
  xla::XlaOp mask_pred = xla::Ne(mask, zero);
  xla::XlaOp update_scalar =
      ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(),
-                ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr);
+                ShapeHelper::ShapeOfXlaOp(input).element_type());
  return xla::Select(mask_pred, update_scalar, input);
}

@@ -291,7 +291,7 @@ xla::XlaOp BuildUpdateSlice(xla::XlaOp input, xla::XlaOp source,
  xla::XlaOp update_source = source;
  if (source_shape.element_type() != input_shape.element_type()) {
    update_source = ConvertTo(source, source_shape.element_type(),
-                              input_shape.element_type(), /*device=*/nullptr);
+                              input_shape.element_type());
  }
  xla::XlaOp reshaped_source =
      XlaHelpers::ReshapeToRank(update_source, input_shape.rank());
26 changes: 6 additions & 20 deletions torch_xla/csrc/dtype.cpp
@@ -75,16 +75,6 @@ bool Use32BitLong() {
  return use_32bit_long;
}

-bool IsTpuDevice(XlaDeviceType hw_type) {
-  static bool spmd_device_is_tpu =
-      (hw_type == XlaDeviceType::SPMD) &&
-      // HACK: find a better way to decide if SPMD is actually a TPU without
-      // accessing the runtime.
-      runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") !=
-          std::string::npos;
-  return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu;
-}
-
}  // namespace

at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
      if (UseBF16()) {
        return xla::PrimitiveType::BF16;
      }
-      if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
-          hw_type == XlaDeviceType::NEURON) {
+      if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
        return xla::PrimitiveType::F32;
      }
      return xla::PrimitiveType::F64;
@@ -175,20 +164,17 @@
      return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
                                         : xla::PrimitiveType::F32;
    case xla::PrimitiveType::U16:
-      return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
-                 ? xla::PrimitiveType::U16
-                 : xla::PrimitiveType::U32;
+      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
+                                              : xla::PrimitiveType::U32;
    case xla::PrimitiveType::S16:
-      return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
-                 ? xla::PrimitiveType::S16
-                 : xla::PrimitiveType::S32;
+      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
+                                              : xla::PrimitiveType::S32;
    case xla::PrimitiveType::S64:
      return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
    case xla::PrimitiveType::U64:
      return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
    case xla::PrimitiveType::C128:
-      return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128
-                                   : xla::PrimitiveType::C64;
+      return xla::PrimitiveType::C128;
    default:
      return type;
  }
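Taken together, the dtype.cpp hunks shrink the device special-casing to NEURON only. A sketch of the resulting MaybeDowncastToXlaDeviceType behavior for the affected types, assuming none of the BF16/F16 downcast environment flags are set:

  // F64  -> F64 on TPU now (previously F32); still F32 on NEURON
  // U16  -> U16 unless the device is NEURON (previously widened to U32 on TPU)
  // S16  -> S16 unless the device is NEURON (previously widened to S32 on TPU)
  // C128 -> C128 on every device (previously narrowed to C64 on TPU)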
19 changes: 9 additions & 10 deletions torch_xla/csrc/helpers.cpp
@@ -29,7 +29,7 @@ xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
  xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2);
  xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result);
  if (type1 == type2 && type1 != result_type) {
-    return ConvertTo(result, result_type, type1, /*device=*/nullptr);
+    return ConvertTo(result, result_type, type1);
  }
  return result;
}
@@ -489,10 +489,10 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(xla::XlaOp op1,
  xla::PrimitiveType type2 = TypeOfXlaOp(op2);
  xla::PrimitiveType result_type = PromoteType(type1, type2);
  if (type1 != result_type) {
-    op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
+    op1 = ConvertTo(op1, type1, result_type);
  }
  if (type2 != result_type) {
-    op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
+    op2 = ConvertTo(op2, type2, result_type);
  }
  return std::pair<xla::XlaOp, xla::XlaOp>(op1, op2);
}
@@ -504,13 +504,13 @@ std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(
  xla::PrimitiveType type3 = TypeOfXlaOp(op3);
  xla::PrimitiveType result_type = PromoteType(type1, type2, type3);
  if (type1 != result_type) {
-    op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
+    op1 = ConvertTo(op1, type1, result_type);
  }
  if (type2 != result_type) {
-    op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
+    op2 = ConvertTo(op2, type2, result_type);
  }
  if (type3 != result_type) {
-    op3 = ConvertTo(op3, type3, result_type, /*device=*/nullptr);
+    op3 = ConvertTo(op3, type3, result_type);
  }
  return std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp>(op1, op2, op3);
}
@@ -519,10 +519,9 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecondValue(
    xla::XlaOp op1, xla::XlaOp op2) {
  xla::PrimitiveType type1 = TypeOfXlaOp(op1);
  xla::PrimitiveType type2 = TypeOfXlaOp(op2);
-  return type1 == type2
-             ? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
-             : std::pair<xla::XlaOp, xla::XlaOp>(
-                   op1, ConvertTo(op2, type2, type1, /*device=*/nullptr));
+  return type1 == type2 ? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
+                        : std::pair<xla::XlaOp, xla::XlaOp>(
+                              op1, ConvertTo(op2, type2, type1));
}

xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
2 changes: 1 addition & 1 deletion torch_xla/csrc/matrix.cpp
@@ -110,7 +110,7 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input,
  xla::XlaOp diag_input = input;
  if (target_shape->element_type() != input_shape.element_type()) {
    diag_input = ConvertTo(input, input_shape.element_type(),
-                           target_shape->element_type(), /*device=*/nullptr);
+                           target_shape->element_type());
  }
  std::vector<int64_t> permutation;
  xla::XlaOp diag_target = target;
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/cast.cpp
@@ -55,8 +55,7 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const {
      stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type();
  xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_;
  xla::XlaOp output =
-      ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to,
-                   /*device=*/nullptr);
+      ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to);
  return ReturnOp(output, loctx);
}

8 changes: 3 additions & 5 deletions torch_xla/csrc/ops/ops.cpp
@@ -262,10 +262,8 @@ torch::lazy::NodePtr Clamp(const torch::lazy::Value& input,
    xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
    xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type,
-                        /*device=*/nullptr);
-    xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type,
-                        /*device=*/nullptr);
+    xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type);
+    xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type);
    return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
  };
  return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max},
@@ -412,7 +410,7 @@ torch::lazy::NodePtr Where(const torch::lazy::Value& condition,
    xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(2));
    xla::XlaOp pred_condition =
        ConvertTo(xla_condition, XlaHelpers::TypeOfXlaOp(xla_condition),
-                  xla::PrimitiveType::PRED, /*device=*/nullptr);
+                  xla::PrimitiveType::PRED);
    auto promoted_branches = XlaHelpers::ValidateShapes(xla_input, xla_other);
    return node.ReturnOp(xla::Select(pred_condition, promoted_branches.first,
                                     promoted_branches.second),
15 changes: 5 additions & 10 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -25,8 +25,7 @@ torch_xla::XlaOpVector Acos::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
  }
  return ReturnOp(xla::Acos(xla_input), loctx);
}
@@ -739,8 +738,7 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
  }
  return ReturnOp(xla::Sinh(xla_input), loctx);
}
@@ -769,8 +767,7 @@ torch_xla::XlaOpVector Sqrt::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
  }
  return ReturnOp(xla::Sqrt(xla_input), loctx);
}
@@ -786,8 +783,7 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
  }
  return ReturnOp(xla::Tan(xla_input), loctx);
}
@@ -796,8 +792,7 @@ torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
  if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
    xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
  }
  return ReturnOp(xla::Tanh(xla_input), loctx);
}
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/prod.cpp
@@ -20,8 +20,7 @@ xla::XlaOp LowerProd(xla::XlaOp input, const std::vector<int64_t>& dimensions,
  xla::XlaOp casted_input;
  if (dtype) {
-    casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
-                             MakeXlaPrimitiveType(*dtype, /*device=*/nullptr),
-                             /*device=*/nullptr);
+    casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
+                             MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
  } else {
    casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
  }
(1 of the 14 changed files is not shown above.)
