diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py
new file mode 100644
index 00000000000..ebac882efdf
--- /dev/null
+++ b/test/pjrt/test_dtypes.py
@@ -0,0 +1,35 @@
+from absl.testing import absltest, parameterized
+import torch
+import torch_xla.core.xla_model as xm
+import torch_xla.runtime as xr
+
+
+class TestDtypes(parameterized.TestCase):
+
+  @parameterized.parameters(torch.float16, torch.float32, torch.float64,
+                            torch.bfloat16, torch.complex64)
+  def test_float_round_trip(self, dtype: torch.dtype):
+    t = torch.randn((3, 3), dtype=dtype)
+    xt = t.to(xm.xla_device())
+    torch.testing.assert_close(xt.cpu(), t)
+
+  @parameterized.parameters(
+      torch.uint8,
+      torch.int8,
+      torch.int16,
+      torch.int32,
+      torch.int64,
+  )
+  def test_int_round_trip(self, dtype: torch.dtype):
+    t = torch.randint(0, 128, (3, 3), dtype=dtype)
+    xt = t.to(xm.xla_device())
+    torch.testing.assert_close(xt.cpu(), t)
+
+  def test_bool_round_trip(self):
+    t = torch.randint(0, 2, (3, 3), dtype=torch.bool)
+    xt = t.to(xm.xla_device())
+    torch.testing.assert_close(xt.cpu(), t)
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 4e5dc6e90f6..1553d53e409 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -128,7 +128,7 @@ function run_torchrun {
     echo "Running torchrun test for GPU $@"
     num_devices=$(nvidia-smi --list-gpus | wc -l)
     PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@
-  fi 
+  fi
 }
 
 function run_torch_op_tests {
@@ -190,6 +190,7 @@
 
 # DO NOT MODIFY
 function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
+  run_test "$CDIR/pjrt/test_dtypes.py"
   run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
 }
diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml
index faf2d67ba08..c65b2e5692e 100644
--- a/test/tpu/xla_test_job.yaml
+++ b/test/tpu/xla_test_job.yaml
@@ -57,6 +57,7 @@ spec:
             python3 /src/pytorch/xla/test/test_autocast.py
             python3 /src/pytorch/xla/test/dynamo/test_dynamo.py
             python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py
+            python3 /src/pytorch/xla/test/pjrt/test_dtypes.py
             python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py
         volumeMounts:
         - mountPath: /dev/shm
diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp
index a920bdb69e9..cd86e0f3169 100644
--- a/torch_xla/csrc/convert_ops.cpp
+++ b/torch_xla/csrc/convert_ops.cpp
@@ -15,11 +15,6 @@ namespace torch_xla {
 
 namespace {
 
-xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) {
-  xla::XlaOp zero = xla::Zero(op.builder(), from);
-  return xla::Ne(op, zero);
-}
-
 xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size,
                          int64_t narrow_size) {
   uint64_t mask_value =
@@ -53,50 +48,20 @@ xla::XlaOp ConvertData(xla::XlaOp op, xla::PrimitiveType type,
 
 }  // namespace
 
 xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
-                     xla::PrimitiveType to,
-                     const torch::lazy::BackendDevice* device) {
+                     xla::PrimitiveType to) {
   if (from == to) {
     return op;
   }
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetDeviceOrCurrent(device).type());
-  if (hw_type != XlaDeviceType::TPU) {
-    return xla::ConvertElementType(op, to);
-  }
-  switch (from) {
-    case xla::PrimitiveType::PRED:
-    case xla::PrimitiveType::S8:
-    case xla::PrimitiveType::U8:
-    case xla::PrimitiveType::S16:
-    case xla::PrimitiveType::U16:
-    case xla::PrimitiveType::S32:
-    case xla::PrimitiveType::U32:
-    case xla::PrimitiveType::BF16:
-    case xla::PrimitiveType::F32:
-      return xla::ConvertElementType(op, to);
-    case xla::PrimitiveType::S64:
-    case xla::PrimitiveType::U64: {
-      switch (to) {
-        case xla::PrimitiveType::PRED:
-          return ExplicitBooleanConvert(op, from);
-        default:
-          return xla::ConvertElementType(op, to);
-      }
-      break;
-    }
-    default:
-      XLA_ERROR() << "Unsupported XLA type " << from;
-  }
+  return xla::ConvertElementType(op, to);
 }
 
 xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
                         xla::PrimitiveType raw_from, xla::PrimitiveType to,
-                        xla::PrimitiveType raw_to,
-                        const torch::lazy::BackendDevice* device) {
+                        xla::PrimitiveType raw_to) {
   if (from != raw_from) {
     op = ConvertData(op, from, raw_from);
   }
-  xla::XlaOp result = ConvertTo(op, from, to, device);
+  xla::XlaOp result = ConvertTo(op, from, to);
   return to == raw_to ? result : ConvertData(result, to, raw_to);
 }
 
@@ -105,8 +70,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
     torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
     op = ConvertTo(
         op, from,
-        MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device),
-        &xla_device);
+        MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device));
   }
   return op;
 }
@@ -120,7 +84,7 @@ xla::XlaOp CastToScalarType(xla::XlaOp input,
   if (dtype) {
     torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
     return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
-                     MakeXlaPrimitiveType(*dtype, &xla_device), &xla_device);
+                     MakeXlaPrimitiveType(*dtype, &xla_device));
   }
   return ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
 }
diff --git a/torch_xla/csrc/convert_ops.h b/torch_xla/csrc/convert_ops.h
index 3dd0ce99f3a..029599667bd 100644
--- a/torch_xla/csrc/convert_ops.h
+++ b/torch_xla/csrc/convert_ops.h
@@ -11,13 +11,11 @@
 namespace torch_xla {
 
 xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
-                     xla::PrimitiveType to,
-                     const torch::lazy::BackendDevice* device);
+                     xla::PrimitiveType to);
 
 xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from,
                         xla::PrimitiveType raw_from, xla::PrimitiveType to,
-                        xla::PrimitiveType raw_to,
-                        const torch::lazy::BackendDevice* device);
+                        xla::PrimitiveType raw_to);
 
 xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from);
 
@@ -32,4 +30,4 @@ xla::XlaOp MaybeConvertTo(xla::XlaOp input, xla::PrimitiveType type);
 
 }  // namespace torch_xla
 
-#endif  // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
\ No newline at end of file
+#endif  // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_
diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp
index bb02ca7da2e..3eb04baa8e8 100644
--- a/torch_xla/csrc/data_ops.cpp
+++ b/torch_xla/csrc/data_ops.cpp
@@ -160,7 +160,7 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
   xla::XlaOp mask_pred = xla::Ne(mask, zero);
   xla::XlaOp update_scalar =
       ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(),
-                ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr);
+                ShapeHelper::ShapeOfXlaOp(input).element_type());
   return xla::Select(mask_pred, update_scalar, input);
 }
 
@@ -291,7 +291,7 @@ xla::XlaOp BuildUpdateSlice(xla::XlaOp input, xla::XlaOp source,
   xla::XlaOp update_source = source;
   if (source_shape.element_type() != input_shape.element_type()) {
     update_source = ConvertTo(source, source_shape.element_type(),
-                              input_shape.element_type(), /*device=*/nullptr);
+                              input_shape.element_type());
   }
   xla::XlaOp reshaped_source =
       XlaHelpers::ReshapeToRank(update_source, input_shape.rank());
diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp
index 103630bcec8..d4ed1b413a6 100644
--- a/torch_xla/csrc/dtype.cpp
+++ b/torch_xla/csrc/dtype.cpp
@@ -75,16 +75,6 @@ bool Use32BitLong() {
   return use_32bit_long;
 }
 
-bool IsTpuDevice(XlaDeviceType hw_type) {
-  static bool spmd_device_is_tpu =
-      (hw_type == XlaDeviceType::SPMD) &&
-      // HACK: find a better way to decide if SPMD is actually a TPU without
-      // accessing the runtime.
-      runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") !=
-          std::string::npos;
-  return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu;
-}
-
 }  // namespace
 
 at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
@@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
       if (UseBF16()) {
         return xla::PrimitiveType::BF16;
       }
-      if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) ||
-          hw_type == XlaDeviceType::NEURON) {
+      if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) {
        return xla::PrimitiveType::F32;
       }
       return xla::PrimitiveType::F64;
@@ -175,20 +164,17 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType(
       return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16
                                          : xla::PrimitiveType::F32;
     case xla::PrimitiveType::U16:
-      return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
-                 ? xla::PrimitiveType::U16
-                 : xla::PrimitiveType::U32;
+      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16
+                                              : xla::PrimitiveType::U32;
     case xla::PrimitiveType::S16:
-      return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON
-                 ? xla::PrimitiveType::S16
-                 : xla::PrimitiveType::S32;
+      return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16
+                                              : xla::PrimitiveType::S32;
     case xla::PrimitiveType::S64:
       return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
     case xla::PrimitiveType::U64:
       return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64;
     case xla::PrimitiveType::C128:
-      return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128
-                                   : xla::PrimitiveType::C64;
+      return xla::PrimitiveType::C128;
     default:
       return type;
   }
 }
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 1e4f3be9f78..895b9f9279e 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -29,7 +29,7 @@ xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
   xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2);
   xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result);
   if (type1 == type2 && type1 != result_type) {
-    return ConvertTo(result, result_type, type1, /*device=*/nullptr);
+    return ConvertTo(result, result_type, type1);
   }
   return result;
 }
@@ -489,10 +489,10 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(xla::XlaOp op1,
   xla::PrimitiveType type2 = TypeOfXlaOp(op2);
   xla::PrimitiveType result_type = PromoteType(type1, type2);
   if (type1 != result_type) {
-    op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
+    op1 = ConvertTo(op1, type1, result_type);
   }
   if (type2 != result_type) {
-    op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
+    op2 = ConvertTo(op2, type2, result_type);
   }
   return std::pair<xla::XlaOp, xla::XlaOp>(op1, op2);
 }
@@ -504,13 +504,13 @@ std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteValues(
   xla::PrimitiveType type3 = TypeOfXlaOp(op3);
   xla::PrimitiveType result_type = PromoteType(type1, type2, type3);
   if (type1 != result_type) {
-    op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr);
+    op1 = ConvertTo(op1, type1, result_type);
   }
   if (type2 != result_type) {
-    op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr);
+    op2 = ConvertTo(op2, type2, result_type);
   }
   if (type3 != result_type) {
-    op3 = ConvertTo(op3, type3, result_type, /*device=*/nullptr);
+    op3 = ConvertTo(op3, type3, result_type);
   }
   return std::tuple<xla::XlaOp, xla::XlaOp, xla::XlaOp>(op1, op2, op3);
 }
@@ -519,10 +519,9 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecondValue(
     xla::XlaOp op1, xla::XlaOp op2) {
   xla::PrimitiveType type1 = TypeOfXlaOp(op1);
   xla::PrimitiveType type2 = TypeOfXlaOp(op2);
-  return type1 == type2
-             ? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
-             : std::pair<xla::XlaOp, xla::XlaOp>(
-                   op1, ConvertTo(op2, type2, type1, /*device=*/nullptr));
+  return type1 == type2 ? std::pair<xla::XlaOp, xla::XlaOp>(op1, op2)
+                        : std::pair<xla::XlaOp, xla::XlaOp>(
+                              op1, ConvertTo(op2, type2, type1));
 }
 
 xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
diff --git a/torch_xla/csrc/matrix.cpp b/torch_xla/csrc/matrix.cpp
index 0cffa5c8d2a..eccfc759a3d 100644
--- a/torch_xla/csrc/matrix.cpp
+++ b/torch_xla/csrc/matrix.cpp
@@ -110,7 +110,7 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input,
   xla::XlaOp diag_input = input;
   if (target_shape->element_type() != input_shape.element_type()) {
     diag_input = ConvertTo(input, input_shape.element_type(),
-                           target_shape->element_type(), /*device=*/nullptr);
+                           target_shape->element_type());
   }
   std::vector<int64_t> permutation;
   xla::XlaOp diag_target = target;
diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp
index 95068640a27..f1a0a1a9072 100644
--- a/torch_xla/csrc/ops/cast.cpp
+++ b/torch_xla/csrc/ops/cast.cpp
@@ -55,8 +55,7 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const {
       stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type();
   xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_;
   xla::XlaOp output =
-      ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to,
-                   /*device=*/nullptr);
+      ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to);
   return ReturnOp(output, loctx);
 }
diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp
index 494f79c2b09..0e1e885dde4 100644
--- a/torch_xla/csrc/ops/ops.cpp
+++ b/torch_xla/csrc/ops/ops.cpp
@@ -262,10 +262,8 @@ torch::lazy::NodePtr Clamp(const torch::lazy::Value& input,
     xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1));
     xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2));
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type,
-                        /*device=*/nullptr);
-    xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type,
-                        /*device=*/nullptr);
+    xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type);
+    xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type);
     return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx);
   };
   return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max},
@@ -412,7 +410,7 @@ torch::lazy::NodePtr Where(const torch::lazy::Value& condition,
     xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(2));
     xla::XlaOp pred_condition =
         ConvertTo(xla_condition, XlaHelpers::TypeOfXlaOp(xla_condition),
-                  xla::PrimitiveType::PRED, /*device=*/nullptr);
+                  xla::PrimitiveType::PRED);
     auto promoted_branches = XlaHelpers::ValidateShapes(xla_input, xla_other);
     return node.ReturnOp(xla::Select(pred_condition, promoted_branches.first,
                                      promoted_branches.second),
diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp
index 286b0207717..e45c7782eb5 100644
--- a/torch_xla/csrc/ops/ops_lower_fn.cpp
+++ b/torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -25,8 +25,7 @@ torch_xla::XlaOpVector Acos::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
   }
   return ReturnOp(xla::Acos(xla_input), loctx);
 }
@@ -739,8 +738,7 @@ torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
   }
   return ReturnOp(xla::Sinh(xla_input), loctx);
 }
@@ -769,8 +767,7 @@ torch_xla::XlaOpVector Sqrt::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
   }
   return ReturnOp(xla::Sqrt(xla_input), loctx);
 }
@@ -786,8 +783,7 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
   }
   return ReturnOp(xla::Tan(xla_input), loctx);
 }
@@ -796,8 +792,7 @@ torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) {
     xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
-    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32,
-                          /*device=*/nullptr);
+    xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32);
   }
   return ReturnOp(xla::Tanh(xla_input), loctx);
 }
diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp
index e61bc9541ea..1790d477cce 100644
--- a/torch_xla/csrc/ops/prod.cpp
+++ b/torch_xla/csrc/ops/prod.cpp
@@ -20,8 +20,7 @@ xla::XlaOp LowerProd(xla::XlaOp input, const std::vector<int64_t>& dimensions,
   xla::XlaOp casted_input;
   if (dtype) {
     casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
-                             MakeXlaPrimitiveType(*dtype, /*device=*/nullptr),
-                             /*device=*/nullptr);
+                             MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
   } else {
     casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input));
   }
diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp
index 34722987954..29230546844 100644
--- a/torch_xla/csrc/xla_lower_util.cpp
+++ b/torch_xla/csrc/xla_lower_util.cpp
@@ -170,7 +170,7 @@ xla::XlaOp CreateIndexAlongDim(
   xla::XlaOp updates = value;
   if (buffer_shape.element_type() != value_shape.element_type()) {
     updates = ConvertTo(updates, value_shape.element_type(),
-                        buffer_shape.element_type(), /*device=*/nullptr);
+                        buffer_shape.element_type());
   }
   if (broadcast_value_to_index) {
     const xla::Shape& index_shape = ShapeHelper::ShapeOfXlaOp(index);
@@ -603,7 +603,7 @@ xla::XlaOp CreateIndexUpdate(
   xla::XlaOp new_values = values;
   if (buffer_shape.element_type() != values_shape.element_type()) {
     new_values = ConvertTo(new_values, values_shape.element_type(),
-                           buffer_shape.element_type(), /*device=*/nullptr);
+                           buffer_shape.element_type());
   }
   new_values = BuildExpand(new_values, expected_values_dims);
   const xla::Shape& new_values_shape = ShapeHelper::ShapeOfXlaOp(new_values);
@@ -654,8 +654,7 @@ XlaOpCombiner NumericAddCombiner() {
     xla::XlaOp numeric_y = ConvertToNumeric(y);
     xla::XlaOp numeric_sum = numeric_x + numeric_y;
     return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum),
-                     XlaHelpers::TypeOfXlaOp(x),
-                     /*device=*/nullptr);
+                     XlaHelpers::TypeOfXlaOp(x));
   };
 }
 
@@ -665,8 +664,7 @@ XlaOpCombiner NumericMulCombiner() {
     xla::XlaOp numeric_y = ConvertToNumeric(y);
     xla::XlaOp numeric_sum = numeric_x * numeric_y;
     return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum),
-                     XlaHelpers::TypeOfXlaOp(x),
-                     /*device=*/nullptr);
+                     XlaHelpers::TypeOfXlaOp(x));
   };
 }
 
@@ -677,8 +675,7 @@ XlaOpCombiner NumericMinCombiner() {
     xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y);
     // xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y);
     return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum),
-                     XlaHelpers::TypeOfXlaOp(x),
-                     /*device=*/nullptr);
+                     XlaHelpers::TypeOfXlaOp(x));
   };
 }
 
@@ -688,8 +685,7 @@ XlaOpCombiner NumericMaxCombiner() {
     xla::XlaOp numeric_y = ConvertToNumeric(y);
     xla::XlaOp numeric_sum = xla::Max(numeric_x, numeric_y);
     return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum),
-                     XlaHelpers::TypeOfXlaOp(x),
-                     /*device=*/nullptr);
+                     XlaHelpers::TypeOfXlaOp(x));
   };
 }
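
Reviewer context (illustration, not part of the patch): with IsTpuDevice() and the TPU-only branch of ConvertTo() removed, wide types such as F64 and C128 are no longer silently narrowed on TPU; only the explicit env-var overrides (e.g. XLA_USE_BF16, XLA_DOWNCAST_BF16, XLA_USE_32BIT_LONG) and the NEURON cases still downcast. A minimal sketch of the round trip the new test exercises, assuming a configured PJRT device:

    import torch
    import torch_xla.core.xla_model as xm

    # float64 was previously narrowed to float32 on TPU; it should now survive
    # the host -> device -> host round trip with dtype and values intact.
    t = torch.randn((3, 3), dtype=torch.float64)
    xt = t.to(xm.xla_device())
    assert xt.cpu().dtype == torch.float64
    torch.testing.assert_close(xt.cpu(), t)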