diff --git a/test/eager/test_eager.py b/test/eager/test_eager.py
index 4320655c066..552382a2dc3 100644
--- a/test/eager/test_eager.py
+++ b/test/eager/test_eager.py
@@ -5,6 +5,8 @@
 import torch_xla
 import torch_xla.debug.metrics as met
 import torch_xla.core.xla_model as xm
+
+import torch.nn as nn
 
 
 class Eager(unittest.TestCase):
@@ -14,6 +16,7 @@ def setUpClass(cls):
     torch_xla.experimental.eager_mode(True)
 
   def test_eager_basic(self):
+    xm.wait_device_ops()
     met.clear_all()
     self.assertTrue(torch_xla.experimental.is_eager_mode())
     device = torch_xla.device()
@@ -89,6 +92,38 @@ def test_eager_set_random_seed(self):
     t2 = torch.randn(12, 13, device=device)
     self.assertTrue(torch.allclose(t1.cpu(), t2.cpu()))
 
+  def test_batch_norm_execute_once(self):
+    xm.wait_device_ops()
+    device = torch_xla.device()
+    m = nn.BatchNorm2d(16).to(device)
+    m.train()
+    input = torch.randn(8, 16, 8, 32).to(device)
+    met.clear_all()
+    output = m(input)
+    self.assertIn('xla::native_batch_norm', met.counter_names())
+    # native_batch_norm has 5 outputs; we should execute them in one graph.
+    # However, in eager mode there will be 2 more graphs to broadcast
+    # 0 to the `u8[0]` and `f32[8,16,8,32]` and one more graph to generate
+    # s64[].
+    self.assertLessEqual(met.metric_data('EagerOpExecuteTime')[0], 4)
+    # make sure running_mean becomes an XLA_Data
+    self.assertIn('Data Shape: f32[16]',
+                  torch_xla._XLAC._get_xla_tensor_debug_info(m.running_mean))
+
+  def test_svd_execute_once(self):
+    device = torch_xla.device()
+    a = torch.randn(5, 3).to(device)
+    xm.wait_device_ops()
+    met.clear_all()
+    u, s, v = torch.svd(a)
+    self.assertIn('xla::_linalg_svd', met.counter_names())
+    # svd has 3 outputs; we should execute them in one graph. However, in
+    # eager mode there will be 1 more graph to create an empty f32[5,3] and 2
+    # more graphs to transpose the results.
+    self.assertLessEqual(met.metric_data('EagerOpExecuteTime')[0], 4)
+    self.assertIn('Data Shape: f32[5,3]',
+                  torch_xla._XLAC._get_xla_tensor_debug_info(u))
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index a8eeef74ff8..9baaeb04a53 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -80,12 +80,14 @@ XLATensorPtr XLATensor::Create(
 XLATensorPtr XLATensor::Create(
     torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device,
-    std::optional<at::ScalarType> logical_element_type) {
+    std::optional<at::ScalarType> logical_element_type,
+    bool delay_eager_executation) {
   XLATensorPtr xtensor = c10::make_intrusive<XLATensor>(
       XLATensor(std::move(ir_value), device, logical_element_type));
   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   graph_executor->RegisterTensor(xtensor->data());
-  if (UseEagerDebugMode() || graph_executor->UseEagerMode()) {
+  if ((UseEagerDebugMode() || graph_executor->UseEagerMode()) &&
+      !delay_eager_executation) {
     std::vector<XLATensorPtr> xtensors({xtensor});
     graph_executor->ApplyEagerSync(xtensors);
   }
@@ -620,26 +622,32 @@ torch::lazy::Value XLATensor::MaybeCastIrValue(
   return ir_value;
 }
 
-XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value) const {
+XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value,
+                                   bool delay_eager_executation) const {
   ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
                               /*logical_element_type=*/std::nullopt);
-  return Create(std::move(ir_value), GetDevice(), dtype_optional());
+  return Create(std::move(ir_value), GetDevice(), dtype_optional(),
+                delay_eager_executation);
 }
 
 XLATensorPtr XLATensor::CreateFrom(
     torch::lazy::Value ir_value,
-    std::optional<at::ScalarType> logical_element_type_opt) const {
+    std::optional<at::ScalarType> logical_element_type_opt,
+    bool delay_eager_executation) const {
   ir_value = MaybeCastIrValue(std::move(ir_value), GetDevice(),
                               logical_element_type_opt);
-  return Create(std::move(ir_value), GetDevice(), logical_element_type_opt);
+  return Create(std::move(ir_value), GetDevice(), logical_element_type_opt,
+                delay_eager_executation);
 }
 
 XLATensorPtr XLATensor::CreateFrom(torch::lazy::Value ir_value,
                                    const torch::lazy::BackendDevice& device,
-                                   at::ScalarType logical_element_type) const {
+                                   at::ScalarType logical_element_type,
+                                   bool delay_eager_executation) const {
   ir_value =
       MaybeCastIrValue(std::move(ir_value), device, logical_element_type);
-  return Create(std::move(ir_value), device, logical_element_type);
+  return Create(std::move(ir_value), device, logical_element_type,
+                delay_eager_executation);
 }
 
 void XLATensor::ApplyPendingGraph() {
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index cf5051f5f3f..d837f2c2ab5 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -156,19 +156,23 @@ class XLATensor : public torch::lazy::LazyTensor {
       std::optional<at::ScalarType> logical_element_type = std::nullopt);
   static XLATensorPtr Create(
       torch::lazy::Value ir_value, const torch::lazy::BackendDevice& device,
-      std::optional<at::ScalarType> logical_element_type = std::nullopt);
+      std::optional<at::ScalarType> logical_element_type = std::nullopt,
+      bool delay_eager_executation = false);
   static XLATensorPtr Create(std::shared_ptr<Data> data);
 
   // Create a new XLA tensor with the same metadata of the input tensor (with
   // possible overrides), and the new IR value.
-  XLATensorPtr CreateFrom(torch::lazy::Value ir_value) const;
+  XLATensorPtr CreateFrom(torch::lazy::Value ir_value,
+                          bool delay_eager_executation = false) const;
   XLATensorPtr CreateFrom(
       torch::lazy::Value ir_value,
-      std::optional<at::ScalarType> logical_element_type_opt) const;
+      std::optional<at::ScalarType> logical_element_type_opt,
+      bool delay_eager_executation = false) const;
   // TODO: We should remove this one once MaybeCastIrValue is no longer needed.
   XLATensorPtr CreateFrom(torch::lazy::Value ir_value,
                           const torch::lazy::BackendDevice& device,
-                          at::ScalarType logical_element_type) const;
+                          at::ScalarType logical_element_type,
+                          bool delay_eager_executation = false) const;
 
   // The default ctor previously created a null LazyTensor (one with no 'data'
   // obj). Creating a null XLATensor is no longer possible, since the same can
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index 7a39a42ecf4..1c3d445b823 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -587,8 +587,14 @@ std::vector<XLATensorPtr> custom_call(
   std::vector<XLATensorPtr> outputs;
   outputs.reserve(output_shapes.size());
   for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(
-        inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i]));
+    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_executation=*/true));
+  }
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `custom_call` in one graph.
+    graph_executor->ApplyEagerSync(outputs);
   }
   return outputs;
 }
@@ -629,8 +635,14 @@ std::vector<XLATensorPtr> gpu_custom_call(
   std::vector<XLATensorPtr> outputs;
   outputs.reserve(output_shapes.size());
   for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(
-        inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i]));
+    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_executation=*/true));
+  }
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `gpu_custom_call` in one graph.
+    graph_executor->ApplyEagerSync(outputs);
   }
   return outputs;
 }
@@ -662,8 +674,14 @@ std::vector<XLATensorPtr> tpu_custom_call(
   std::vector<XLATensorPtr> outputs;
   outputs.reserve(output_shapes.size());
   for (size_t i = 0; i < output_shapes.size(); ++i) {
-    outputs.push_back(
-        inputs[0]->CreateFrom(torch::lazy::Value(node, i), output_dtypes[i]));
+    outputs.push_back(inputs[0]->CreateFrom(torch::lazy::Value(node, i),
+                                            output_dtypes[i],
+                                            /*delay_eager_executation=*/true));
+  }
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `tpu_custom_call` in one graph.
+    graph_executor->ApplyEagerSync(outputs);
   }
   return outputs;
 }
@@ -715,9 +733,19 @@ void sgd_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
       lr_value, dampening_value,
       /*use_weight_decay=*/weight_decay != 0,
       /*use_momentum=*/momentum != 0, /*use_nesterov=*/nesterov);
-  step->SetInPlaceIrValue(torch::lazy::Value(node, 0));
-  param->SetInPlaceIrValue(torch::lazy::Value(node, 1));
-  buf->SetInPlaceIrValue(torch::lazy::Value(node, 2));
+  step->SetInPlaceIrValue(torch::lazy::Value(node, 0),
+                          /*delay_eager_executation=*/true);
+  param->SetInPlaceIrValue(torch::lazy::Value(node, 1),
+                           /*delay_eager_executation=*/true);
+  buf->SetInPlaceIrValue(torch::lazy::Value(node, 2),
+                         /*delay_eager_executation=*/true);
+
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `sgd_optimizer_step_` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {step, param, buf};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
 }
 
 void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
@@ -747,12 +775,27 @@ void adam_optimizer_step_(const XLATensorPtr& found_inf, XLATensorPtr& step,
       weight_decay_value, eps_value,
       /*use_weight_decay=*/weight_decay != 0,
       /*use_amsgrad=*/amsgrad, /*use_adamw=*/use_adamw);
-  step->SetInPlaceIrValue(torch::lazy::Value(node, 0));
-  param->SetInPlaceIrValue(torch::lazy::Value(node, 1));
-  exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2));
-  exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3));
+  step->SetInPlaceIrValue(torch::lazy::Value(node, 0),
+                          /*delay_eager_executation=*/true);
+  param->SetInPlaceIrValue(torch::lazy::Value(node, 1),
+                           /*delay_eager_executation=*/true);
+  exp_avg->SetInPlaceIrValue(torch::lazy::Value(node, 2),
+                             /*delay_eager_executation=*/true);
+  exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 3),
+                                /*delay_eager_executation=*/true);
   if (amsgrad) {
-    max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4));
+    max_exp_avg_sq->SetInPlaceIrValue(torch::lazy::Value(node, 4),
+                                      /*delay_eager_executation=*/true);
+  }
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `adam_optimizer_step_` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {step, param, exp_avg,
+                                                 exp_avg_sq};
+    if (amsgrad) {
+      tensors_to_sync.push_back(max_exp_avg_sq);
+    }
+    graph_executor->ApplyEagerSync(tensors_to_sync);
   }
 }
@@ -819,9 +862,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> adaptive_max_pool2d(
     const XLATensorPtr& input, std::vector<int64_t> output_size) {
   torch::lazy::NodePtr node = torch::lazy::MakeNode<AdaptiveMaxPool2d>(
       input->GetIrValue(), output_size);
-  XLATensorPtr out = input->CreateFrom(torch::lazy::Value(node, 0));
+  XLATensorPtr out = input->CreateFrom(torch::lazy::Value(node, 0),
+                                       /*delay_eager_executation=*/true);
   XLATensorPtr indices =
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long);
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `adaptive_max_pool2d` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {out, indices};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
   return std::make_tuple(std::move(out), std::move(indices));
 }
@@ -855,9 +906,19 @@ void _amp_foreach_non_finite_check_and_unscale_(std::vector<XLATensorPtr> self,
       torch::lazy::MakeNode(
           inputs, found_inf->GetIrValue(), new_inv_scale->GetIrValue());
   for (size_t i = 0; i < self.size(); ++i) {
-    self[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
+    self[i]->SetInPlaceIrValue(torch::lazy::Value(node, i),
+                               /*delay_eager_executation=*/true);
+  }
+  found_inf->SetInPlaceIrValue(torch::lazy::Value(node, self.size()),
+                               /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run
+    // `_amp_foreach_non_finite_check_and_unscale_` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = self;
+    tensors_to_sync.push_back(found_inf);
+    graph_executor->ApplyEagerSync(tensors_to_sync);
   }
-  found_inf->SetInPlaceIrValue(torch::lazy::Value(node, self.size()));
 }
 
 void _amp_update_scale_(XLATensorPtr& current_scale,
@@ -869,8 +930,16 @@ void _amp_update_scale_(XLATensorPtr& current_scale,
       growth_tracker->GetIrValue(), current_scale->GetIrValue(),
       found_inf->GetIrValue(), scale_growth_factor, scale_backoff_factor,
      growth_interval);
-  growth_tracker->SetInPlaceIrValue(torch::lazy::Value(node, 1));
-  current_scale->SetInPlaceIrValue(torch::lazy::Value(node, 0));
+  growth_tracker->SetInPlaceIrValue(torch::lazy::Value(node, 1),
+                                    /*delay_eager_executation=*/true);
+  current_scale->SetInPlaceIrValue(torch::lazy::Value(node, 0),
+                                   /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `_amp_update_scale_` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {growth_tracker, current_scale};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
 }
 
 XLATensorPtr abs(const XLATensorPtr& input) {
@@ -1194,12 +1263,20 @@ convolution_backward_overrideable(
     out_backprop->GetIrValue(), input->GetIrValue(), weight->GetIrValue(),
     std::move(stride), std::move(padding), std::move(dilation), transposed,
     std::move(output_padding), groups);
-  XLATensorPtr grad_input =
-      out_backprop->CreateFrom(torch::lazy::Value(node, 0));
-  XLATensorPtr grad_weight =
-      out_backprop->CreateFrom(torch::lazy::Value(node, 1));
-  XLATensorPtr grad_bias =
-      out_backprop->CreateFrom(torch::lazy::Value(node, 2));
+  XLATensorPtr grad_input = out_backprop->CreateFrom(
+      torch::lazy::Value(node, 0), /*delay_eager_executation=*/true);
+  XLATensorPtr grad_weight = out_backprop->CreateFrom(
+      torch::lazy::Value(node, 1), /*delay_eager_executation=*/true);
+  XLATensorPtr grad_bias = out_backprop->CreateFrom(
+      torch::lazy::Value(node, 2), /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `convolution_backward_overrideable`
+    // in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {grad_input, grad_weight,
+                                                 grad_bias};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
   return std::make_tuple(std::move(grad_input), std::move(grad_weight),
                          std::move(grad_bias));
 }
@@ -1364,9 +1441,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> einsum_backward(
       grad_output->GetIrValue(), irs, equation);
   if (node->num_outputs() == 2) {
-    return std::make_tuple(
-        grad_output->CreateFrom(torch::lazy::Value(node, 0)),
-        grad_output->CreateFrom(torch::lazy::Value(node, 1)));
+    XLATensorPtr t1 = grad_output->CreateFrom(torch::lazy::Value(node, 0),
+                                              /*delay_eager_executation=*/true);
+    XLATensorPtr t2 = grad_output->CreateFrom(torch::lazy::Value(node, 1),
+                                              /*delay_eager_executation=*/true);
+    XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+    if (graph_executor->UseEagerMode()) {
+      // Execute the HLO that will run `einsum_backward` in one graph.
+      std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+      graph_executor->ApplyEagerSync(tensors_to_sync);
+    }
+    return std::make_tuple(t1, t2);
   } else {
     return std::make_tuple(grad_output->CreateFrom(torch::lazy::Value(node, 0)),
                            XLATensorPtr());
@@ -1411,10 +1496,22 @@ embedding_bag(const XLATensorPtr& weight, const XLATensorPtr& indices,
   torch::lazy::NodePtr node = torch::lazy::MakeNode<EmbeddingBag>(
       weight->GetIrValue(), indices->GetIrValue(), offsets->GetIrValue(), mode,
       per_sample_weights->GetIrValue(), include_last_offset);
-  return std::make_tuple(weight->CreateFrom(torch::lazy::Value(node, 0)),
-                         weight->CreateFrom(torch::lazy::Value(node, 1)),
-                         weight->CreateFrom(torch::lazy::Value(node, 2)),
-                         weight->CreateFrom(torch::lazy::Value(node, 3)));
+
+  XLATensorPtr t1 = weight->CreateFrom(torch::lazy::Value(node, 0),
+                                       /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = weight->CreateFrom(torch::lazy::Value(node, 1),
+                                       /*delay_eager_executation=*/true);
+  XLATensorPtr t3 = weight->CreateFrom(torch::lazy::Value(node, 2),
+                                       /*delay_eager_executation=*/true);
+  XLATensorPtr t4 = weight->CreateFrom(torch::lazy::Value(node, 3),
+                                       /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `embedding_bag` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2, t3, t4};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2, t3, t4);
 }
 
 XLATensorPtr exp(const XLATensorPtr& input) {
@@ -1693,9 +1790,18 @@ std::tuple<XLATensorPtr, XLATensorPtr> kthvalue(const XLATensorPtr& input,
       input->GetIrValue(), k,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
       keepdim);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `kthvalue` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr le(const XLATensorPtr& input, const at::Scalar& other) {
@@ -1908,9 +2014,18 @@ std::tuple<XLATensorPtr, XLATensorPtr> max(const XLATensorPtr& input,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MaxInDim>(
      input->GetIrValue(), canonical_dim, keepdim);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `max` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
@@ -1919,8 +2034,16 @@ void max_out(XLATensorPtr& max, XLATensorPtr& max_values,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MaxInDim>(
       input->GetIrValue(), canonical_dim, keepdim);
-  max->SetIrValue(torch::lazy::Value(node, 0));
-  max_values->SetIrValue(torch::lazy::Value(node, 1));
+  max->SetIrValue(torch::lazy::Value(node, 0),
+                  /*delay_eager_executation=*/true);
+  max_values->SetIrValue(torch::lazy::Value(node, 1),
+                         /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `max_out` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {max, max_values};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
 }
 
 std::tuple<XLATensorPtr, XLATensorPtr> max_pool_nd(
@@ -1933,9 +2056,19 @@ std::tuple<XLATensorPtr, XLATensorPtr> max_pool_nd(
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MaxPoolNd>(
       input->GetIrValue(), spatial_dim_count, std::move(kernel_size),
       std::move(stride), std::move(padding), ceil_mode);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `max_pool_nd` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr max_pool_nd_backward(
@@ -1989,9 +2122,18 @@ std::tuple<XLATensorPtr, XLATensorPtr> min(const XLATensorPtr& input,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MinInDim>(
       input->GetIrValue(), canonical_dim, keepdim);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `min` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
@@ -2000,8 +2142,16 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,
       torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank());
   torch::lazy::NodePtr node = torch::lazy::MakeNode<MinInDim>(
       input->GetIrValue(), canonical_dim, keepdim);
-  min->SetIrValue(torch::lazy::Value(node, 0));
-  min_indices->SetIrValue(torch::lazy::Value(node, 1));
+  min->SetIrValue(torch::lazy::Value(node, 0),
+                  /*delay_eager_executation=*/true);
+  min_indices->SetIrValue(torch::lazy::Value(node, 1),
+                          /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `min_out` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {min, min_indices};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
 }
 
 XLATensorPtr mish(const XLATensorPtr& input) {
@@ -2132,25 +2282,49 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm(
   torch::lazy::NodePtr node = torch::lazy::MakeNode<NativeBatchNormForward>(
      input->GetIrValue(), weight_value, bias_value, running_mean_value,
      running_var_value, training, eps);
-  XLATensorPtr output = input->CreateFrom(torch::lazy::Value(node, 0));
+  XLATensorPtr output = input->CreateFrom(torch::lazy::Value(node, 0),
+                                          /*delay_eager_executation=*/true);
   XLATensorPtr mean;
   XLATensorPtr variance_inverse;
   if (training) {
-    mean = input->CreateFrom(torch::lazy::Value(node, 1));
-    variance_inverse = input->CreateFrom(torch::lazy::Value(node, 3));
+    mean = input->CreateFrom(torch::lazy::Value(node, 1),
+                             /*delay_eager_executation=*/true);
+    variance_inverse =
+        input->CreateFrom(torch::lazy::Value(node, 3),
+                          /*delay_eager_executation=*/true);
     if (running_mean) {
-      running_mean->SetIrValue(torch::lazy::MakeNode<LinearInterpolation>(
-          mean->GetIrValue(), running_mean->GetIrValue(), momentum));
+      running_mean->SetIrValue(
+          torch::lazy::MakeNode<LinearInterpolation>(
+              mean->GetIrValue(), running_mean->GetIrValue(), momentum),
+          /*delay_eager_executation=*/true);
     }
     if (running_var) {
-      running_var->SetIrValue(torch::lazy::MakeNode<LinearInterpolation>(
-          torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum));
+      running_var->SetIrValue(
+          torch::lazy::MakeNode<LinearInterpolation>(
+              torch::lazy::Value(node, 2), running_var->GetIrValue(), momentum),
+          /*delay_eager_executation=*/true);
     }
   } else {
     at::Tensor at_input = bridge::AtenFromXlaTensor(input);
     mean = bridge::GetXlaTensor(at::empty({0}, at_input.options()));
     variance_inverse = bridge::GetXlaTensor(at::empty({0}, at_input.options()));
   }
+
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `native_batch_norm` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {output};
+    if (training) {
+      tensors_to_sync.push_back(mean);
+      tensors_to_sync.push_back(variance_inverse);
+      if (running_mean) {
+        tensors_to_sync.push_back(running_mean);
+      }
+      if (running_var) {
+        tensors_to_sync.push_back(running_var);
+      }
+    }
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
   return std::make_tuple(std::move(output), std::move(mean),
                          std::move(variance_inverse));
 }
@@ -2165,9 +2339,20 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> native_batch_norm_backward(
   torch::lazy::NodePtr node = torch::lazy::MakeNode<NativeBatchNormBackward>(
      grad_out->GetIrValue(), input->GetIrValue(), weight_value,
      save_mean->GetIrValue(), save_invstd->GetIrValue(), training, eps);
-  XLATensorPtr grad_input = input->CreateFrom(torch::lazy::Value(node, 0));
-  XLATensorPtr grad_weight = input->CreateFrom(torch::lazy::Value(node, 1));
-  XLATensorPtr grad_bias = input->CreateFrom(torch::lazy::Value(node, 2));
+  XLATensorPtr grad_input = input->CreateFrom(torch::lazy::Value(node, 0),
+                                              /*delay_eager_executation=*/true);
+  XLATensorPtr grad_weight = input->CreateFrom(
+      torch::lazy::Value(node, 1), /*delay_eager_executation=*/true);
+  XLATensorPtr grad_bias = input->CreateFrom(torch::lazy::Value(node, 2),
+                                             /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `native_batch_norm_backward`
+    // in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {grad_input, grad_weight,
+                                                 grad_bias};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
   return std::make_tuple(std::move(grad_input), std::move(grad_weight),
                          std::move(grad_bias));
 }
@@ -2177,9 +2362,18 @@ std::tuple<XLATensorPtr, XLATensorPtr> native_dropout(
   torch::lazy::NodePtr node = torch::lazy::MakeNode<NativeDropout>(
      input->GetIrValue(),
      XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), p, train);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Bool));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Bool,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `native_dropout` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr ne(const XLATensorPtr& input, const at::Scalar& other) {
@@ -2377,8 +2571,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> prelu_backward(
     const XLATensorPtr& weight) {
   torch::lazy::NodePtr node = PreluBackward(
       grad->GetIrValue(), input->GetIrValue(), weight->GetIrValue());
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `prelu_backward` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr prod(const XLATensorPtr& input, std::vector<int64_t> dimensions,
@@ -2408,8 +2611,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> qr(const XLATensorPtr& input, bool some) {
   torch::lazy::NodePtr node =
       torch::lazy::MakeNode<QR>(input->GetIrValue(), some);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `qr` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr quantize_tensor(const XLATensorPtr& input,
@@ -2799,8 +3011,17 @@ XLATensorPtr slice(const XLATensorPtr& input, int64_t dim, int64_t start,
 
 std::tuple<XLATensorPtr, XLATensorPtr> slogdet(const XLATensorPtr& input) {
   torch::lazy::NodePtr node = SLogDet(input->GetIrValue());
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `slogdet` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr smooth_l1_loss(const XLATensorPtr& input,
@@ -2955,8 +3176,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> std_mean(const XLATensorPtr& input,
           torch_xla::runtime::util::ToVector<int64_t>(dimensions),
          input->shape().get().rank()),
       correction, keep_reduced_dimensions);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `std_mean` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr sub(const XLATensorPtr& input, const XLATensorPtr& other,
@@ -3020,9 +3250,19 @@ std::tuple<XLATensorPtr, XLATensorPtr, XLATensorPtr> svd(
     const XLATensorPtr& input, bool some, bool compute_uv) {
   torch::lazy::NodePtr node =
       torch::lazy::MakeNode<SVD>(input->GetIrValue(), some, compute_uv);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)),
-                         input->CreateFrom(torch::lazy::Value(node, 2)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t3 = input->CreateFrom(torch::lazy::Value(node, 2),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `svd` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2, t3};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2, t3);
 }
 
 XLATensorPtr tanh_backward(const XLATensorPtr& grad_output,
@@ -3071,9 +3311,18 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
       input->GetIrValue(), k,
      torch::lazy::GetCanonicalDimensionIndex(dim, input->shape().get().rank()),
      largest, sorted, stable);
-  return std::make_tuple(
-      input->CreateFrom(torch::lazy::Value(node, 0)),
-      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 =
+      input->CreateFrom(torch::lazy::Value(node, 1), at::ScalarType::Long,
+                        /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `topk` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 XLATensorPtr trace(const XLATensorPtr& input) {
@@ -3128,8 +3377,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
   torch::lazy::NodePtr node = torch::lazy::MakeNode<TriangularSolve>(
      rhs->GetIrValue(), lhs->GetIrValue(), left_side, !upper, transpose,
      unitriangular);
-  return std::make_tuple(rhs->CreateFrom(torch::lazy::Value(node, 0)),
-                         rhs->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = rhs->CreateFrom(torch::lazy::Value(node, 0),
+                                    /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = rhs->CreateFrom(torch::lazy::Value(node, 1),
+                                    /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `triangular_solve` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
@@ -3279,8 +3537,17 @@ std::tuple<XLATensorPtr, XLATensorPtr> var_mean(const XLATensorPtr& input,
           torch_xla::runtime::util::ToVector<int64_t>(dimensions),
          input->shape().get().rank()),
      correction, keep_reduced_dimensions);
-  return std::make_tuple(input->CreateFrom(torch::lazy::Value(node, 0)),
-                         input->CreateFrom(torch::lazy::Value(node, 1)));
+  XLATensorPtr t1 = input->CreateFrom(torch::lazy::Value(node, 0),
+                                      /*delay_eager_executation=*/true);
+  XLATensorPtr t2 = input->CreateFrom(torch::lazy::Value(node, 1),
+                                      /*delay_eager_executation=*/true);
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that will run `var_mean` in one graph.
+    std::vector<XLATensorPtr> tensors_to_sync = {t1, t2};
+    graph_executor->ApplyEagerSync(tensors_to_sync);
+  }
+  return std::make_tuple(t1, t2);
 }
 
 void zero_(XLATensorPtr& input) {
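
Note for reviewers (not part of the patch): a minimal sketch of how to observe the single-graph execution this change enables, assuming a torch_xla build with the patch applied. Every API used below (`torch_xla.experimental.eager_mode`, `xm.wait_device_ops`, `met.clear_all`, `met.metric_data`) already appears in the tests above; the exact `EagerOpExecuteTime` count is device dependent, so only the upper bound the tests assert is shown.

    import torch
    import torch.nn as nn
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    # Turn on eager mode: each traced op is compiled and executed right away
    # instead of being folded into one large lazy graph.
    torch_xla.experimental.eager_mode(True)

    device = torch_xla.device()
    m = nn.BatchNorm2d(16).to(device)
    m.train()
    inp = torch.randn(8, 16, 8, 32).to(device)

    xm.wait_device_ops()  # drain pending executions so the metric starts clean
    met.clear_all()

    out = m(inp)

    # With this patch, all five outputs of xla::native_batch_norm come from a
    # single graph execution; the few extra executions are the small helper
    # graphs (zero broadcasts, seed generation) described in the test comments.
    print(met.metric_data('EagerOpExecuteTime')[0])  # expected: <= 4

Without the `delay_eager_executation` flag, each output tensor created via `CreateFrom`/`SetIrValue` would trigger its own eager sync, so a multi-output op could compile and run several nearly identical graphs; delaying the sync and calling `ApplyEagerSync` once on all outputs collapses that into one execution.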