diff --git a/examples/eager/train_decoder_only_eager_multi_process.py b/examples/eager/train_decoder_only_eager_multi_process.py
new file mode 100644
index 00000000000..c686083fe06
--- /dev/null
+++ b/examples/eager/train_decoder_only_eager_multi_process.py
@@ -0,0 +1,26 @@
+import sys
+import os
+example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
+sys.path.append(example_folder)
+from train_decoder_only_base import TrainDecoderOnlyBase
+
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.core.xla_model as xm
+
+
+class TrainDecoderXLADDP(TrainDecoderOnlyBase):
+
+  def run_optimizer(self):
+    # optimizer_step will call `optimizer.step()` and all_reduce the gradients.
+    xm.optimizer_step(self.optimizer)
+
+
+def _mp_fn(index):
+  import torch_xla
+  torch_xla.experimental.eager_mode(True)
+  xla_ddp = TrainDecoderXLADDP()
+  xla_ddp.start_training()
+
+
+if __name__ == '__main__':
+  xmp.spawn(_mp_fn, args=())
diff --git a/examples/train_decoder_only_base.py b/examples/train_decoder_only_base.py
index 63a6ad028aa..2536d1c2f2b 100644
--- a/examples/train_decoder_only_base.py
+++ b/examples/train_decoder_only_base.py
@@ -20,7 +20,7 @@ def __init__(self):
     self.config = DecoderOnlyConfig()
     self.batch_size = 16
     self.seq_len = 512
-    self.num_steps = 300
+    self.num_steps = 200
     self.num_epochs = 1
     self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
     # For the purpose of this example, we are going to use fake data.
diff --git a/test/eager/test_eager_all_reduce_in_place.py b/test/eager/test_eager_all_reduce_in_place.py
new file mode 100644
index 00000000000..9ee91da2074
--- /dev/null
+++ b/test/eager/test_eager_all_reduce_in_place.py
@@ -0,0 +1,40 @@
+import torch
+import torch_xla
+
+import torch_xla.core.xla_model as xm
+import torch_xla.debug
+import torch_xla.distributed.xla_multiprocessing as xmp
+import torch_xla.debug.metrics as met
+
+
+def _mp_fn(index):
+  import torch_xla
+  torch_xla.experimental.eager_mode(True)
+
+  device = torch_xla.device()
+
+  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
+    return
+
+  ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)
+  ordinal_tensor_2 = torch.tensor([index], dtype=torch.int32).to(device)
+  xm.wait_device_ops()
+  met.clear_all()
+
+  # all_reduce with a list of tensors as input is an in-place op. This is
+  # what optimizer_step uses.
+  xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor_1, ordinal_tensor_2])
+
+  xm.wait_device_ops()
+  assert met.metric_data("EagerOpExecuteTime")[0] == 1
+
+  num_device = torch_xla.runtime.global_runtime_device_count()
+  expected_sum = (num_device - 1) * num_device / 2
+  expected_1 = torch.tensor([(expected_sum)], dtype=torch.float)
+  expected_2 = torch.tensor([(expected_sum)], dtype=torch.int32)
+  assert torch.allclose(expected_1, ordinal_tensor_1.cpu())
+  assert torch.allclose(expected_2, ordinal_tensor_2.cpu())
+
+
+if __name__ == '__main__':
+  xmp.spawn(_mp_fn, args=())
diff --git a/test/run_tests.sh b/test/run_tests.sh
index 5b2d5bdadc8..9f641e8effe 100755
--- a/test/run_tests.sh
+++ b/test/run_tests.sh
@@ -207,6 +207,7 @@ function run_xla_op_tests2 {
   run_test "$CDIR/eager/test_eager.py"
   run_test "$CDIR/eager/test_eager_with_xla_compile.py"
   run_test "$CDIR/eager/test_eager_with_torch_compile.py"
+  run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
 }

 # All the new xla op tests should go to run_xla_op_tests3
diff --git a/test/tpu/run_tests.sh b/test/tpu/run_tests.sh
index dfff128a5ab..f6b16adfc10 100755
--- a/test/tpu/run_tests.sh
+++ b/test/tpu/run_tests.sh
@@ -43,4 +43,5 @@ TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; prin
 if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
   python3 examples/eager/train_decoder_only_eager.py
   python3 examples/eager/train_decoder_only_eager_with_compile.py
+  python3 examples/eager/train_decoder_only_eager_multi_process.py
 fi
diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp
index 13eb94ae30d..a8eeef74ff8 100644
--- a/torch_xla/csrc/tensor.cpp
+++ b/torch_xla/csrc/tensor.cpp
@@ -351,7 +351,8 @@ void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
   data()->is_cloned = false;
 }

-void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
+void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
+                                  bool delay_eager_executation) {
   auto xla_shape = shape();
   if (xla_shape.get().element_type() != GetXlaShape(ir_value).element_type()) {
     ir_value =
@@ -361,7 +362,7 @@

   XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
   // in place update should also be triggered eagerly if configured
-  if (graph_executor->UseEagerMode()) {
+  if (graph_executor->UseEagerMode() && !delay_eager_executation) {
     std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
     graph_executor->ApplyEagerSync(xtensors);
   }
diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h
index 101c6c54b75..cf5051f5f3f 100644
--- a/torch_xla/csrc/tensor.h
+++ b/torch_xla/csrc/tensor.h
@@ -228,7 +228,8 @@ class XLATensor : public torch::lazy::LazyTensor {
   // TODO(alanwaketan): Reuse the upstream ones once Functionalization is done.
   torch::lazy::Value GetIrValue() const;
   void SetIrValue(torch::lazy::Value ir_value, bool inplace = true);
-  void SetInPlaceIrValue(torch::lazy::Value ir_value);
+  void SetInPlaceIrValue(torch::lazy::Value ir_value,
+                         bool delay_eager_executation = false);

   // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
   std::optional<at::Tensor> CurrentTensorData() const;
diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp
index cc1b34e9043..79b29b6e05f 100644
--- a/torch_xla/csrc/tensor_methods.cpp
+++ b/torch_xla/csrc/tensor_methods.cpp
@@ -370,10 +370,26 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
       reduce_type, input_values, GetAllReduceToken(inputs.front()->GetDevice()),
       scale, std::move(groups), pin_layout);
   for (size_t i = 0; i < inputs.size(); ++i) {
-    inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
+    // In eager mode we don't want to execute the IR for each tensor because
+    // that would execute the `all_reduce` once per tensor.
+    inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i),
+                                 /*delay_eager_executation=*/true);
+  }
+
+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that runs the `all_reduce` and updates all tensors in
+    // place in a single graph.
+    graph_executor->ApplyEagerSync(
+        const_cast<std::vector<XLATensorPtr>&>(inputs));
+  } else {
+    // all_reduce_token is used to enforce the order of the cc ops. There is
+    // no point in setting it for eager mode since each cc op is executed
+    // independently.
+    SetAllReduceToken(
+        inputs.front()->GetDevice(),
+        std::make_shared<torch::lazy::Value>(node, inputs.size()));
   }
-  SetAllReduceToken(inputs.front()->GetDevice(),
-                    std::make_shared<torch::lazy::Value>(node, inputs.size()));
 }

 std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
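
For reference, a minimal single-process sketch (not part of the patch) of the behavior the new test asserts: with eager mode enabled, `xm.all_reduce` over a list of tensors updates them in place through a single eager execution, which `EagerOpExecuteTime` should reflect. The tensor names and single-device setup are illustrative assumptions, not code from this PR.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

t1 = torch.tensor([1.0]).to(device)
t2 = torch.tensor([2], dtype=torch.int32).to(device)
xm.wait_device_ops()
met.clear_all()

# List input -> in-place all_reduce; both tensors should be updated by a
# single eager graph execution rather than one execution per tensor.
xm.all_reduce(xm.REDUCE_SUM, [t1, t2])
xm.wait_device_ops()
print(met.metric_data("EagerOpExecuteTime")[0])  # expected: 1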