Support eager mode for multi-process training #7327

Merged 4 commits on Jun 24, 2024
26 changes: 26 additions & 0 deletions examples/eager/train_decoder_only_eager_multi_process.py
@@ -0,0 +1,26 @@
import sys
import os
# Make train_decoder_only_base importable from the parent examples/ folder.
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
sys.path.append(example_folder)
from train_decoder_only_base import TrainDecoderOnlyBase

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm


class TrainDecoderXLADDP(TrainDecoderOnlyBase):

  def run_optimizer(self):
    # optimizer_step will all_reduce the gradients across replicas and then
    # call `optimizer.step()`.
    xm.optimizer_step(self.optimizer)


def _mp_fn(index):
  import torch_xla
  torch_xla.experimental.eager_mode(True)
  xla_ddp = TrainDecoderXLADDP()
  xla_ddp.start_training()


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
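For reference, here is a rough sketch of what `xm.optimizer_step` does for this multi-process setup. This is a simplified illustration, not the actual `torch_xla.core.xla_model` source; `optimizer_step_sketch` is a hypothetical name, and `xm` is assumed imported as in the example above.

# Simplified sketch, not the actual torch_xla implementation.
def optimizer_step_sketch(optimizer):
  # All-reduce (and scale) the gradients across all participating devices,
  # then apply the regular local optimizer update.
  xm.reduce_gradients(optimizer)
  optimizer.step()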
2 changes: 1 addition & 1 deletion examples/train_decoder_only_base.py
@@ -20,7 +20,7 @@ def __init__(self):
    self.config = DecoderOnlyConfig()
    self.batch_size = 16
    self.seq_len = 512
-   self.num_steps = 300
+   self.num_steps = 200
    self.num_epochs = 1
    self.train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    # For the purpose of this example, we are going to use fake data.
40 changes: 40 additions & 0 deletions test/eager/test_eager_all_reduce_in_place.py
@@ -0,0 +1,40 @@
import torch
import torch_xla

import torch_xla.core.xla_model as xm
import torch_xla.debug
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.debug.metrics as met


def _mp_fn(index):
  import torch_xla
  torch_xla.experimental.eager_mode(True)

  device = torch_xla.device()

  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
    return

  ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)
  ordinal_tensor_2 = torch.tensor([index], dtype=torch.int32).to(device)
  xm.wait_device_ops()
  met.clear_all()

  # all_reduce with a list of tensors as input is an in-place op. This is
  # the form used by optimizer_step.
  xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor_1, ordinal_tensor_2])

  xm.wait_device_ops()
  # Both in-place all_reduces should have been folded into a single eager
  # execution.
  assert met.metric_data("EagerOpExecuteTime")[0] == 1

  num_device = torch_xla.runtime.global_runtime_device_count()
  expected_sum = (num_device - 1) * num_device / 2
  expected_1 = torch.tensor([expected_sum], dtype=torch.float)
  expected_2 = torch.tensor([expected_sum], dtype=torch.int32)
  assert torch.allclose(expected_1, ordinal_tensor_1.cpu())
  assert torch.allclose(expected_2, ordinal_tensor_2.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())
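The assertion above relies on `xm.all_reduce` having two calling conventions. A minimal sketch of the difference, meant to run inside a process spawned by `xmp.spawn` with the same imports as the test above:

device = torch_xla.device()
t = torch.ones(4, device=device)

# Functional form: a single tensor in, the reduced tensor out.
reduced = xm.all_reduce(xm.REDUCE_SUM, t)

# In-place form: a list of tensors that are mutated directly. This is the
# form used by optimizer_step and exercised by the test above.
xm.all_reduce(xm.REDUCE_SUM, [t])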
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -207,6 +207,7 @@ function run_xla_op_tests2 {
  run_test "$CDIR/eager/test_eager.py"
  run_test "$CDIR/eager/test_eager_with_xla_compile.py"
  run_test "$CDIR/eager/test_eager_with_torch_compile.py"
+  run_test "$CDIR/eager/test_eager_all_reduce_in_place.py"
}

# All the new xla op tests should go to run_xla_op_tests3
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -43,4 +43,5 @@ TPU_VERSION=$(python -c "import sys; sys.path.remove(''); import torch_xla; prin
if [[ -n "$TPU_VERSION" && "$TPU_VERSION" == "4" ]]; then
  python3 examples/eager/train_decoder_only_eager.py
  python3 examples/eager/train_decoder_only_eager_with_compile.py
+  python3 examples/eager/train_decoder_only_eager_multi_process.py
fi
5 changes: 3 additions & 2 deletions torch_xla/csrc/tensor.cpp
@@ -351,7 +351,8 @@ void XLATensor::SetIrValue(torch::lazy::Value ir_value, bool inplace) {
  data()->is_cloned = false;
}

-void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
+void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value,
+                                  bool delay_eager_executation) {
  auto xla_shape = shape();
  if (xla_shape.get().element_type() != GetXlaShape(ir_value).element_type()) {
    ir_value =
@@ -361,7 +362,7 @@ void XLATensor::SetInPlaceIrValue(torch::lazy::Value ir_value) {
  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();

  // In-place update should also be triggered eagerly if configured.
-  if (graph_executor->UseEagerMode()) {
+  if (graph_executor->UseEagerMode() && !delay_eager_executation) {
    std::vector<XLATensorPtr> xtensors({c10::make_intrusive<XLATensor>(*this)});
    graph_executor->ApplyEagerSync(xtensors);
  }
3 changes: 2 additions & 1 deletion torch_xla/csrc/tensor.h
@@ -228,7 +228,8 @@ class XLATensor : public torch::lazy::LazyTensor {
  // TODO(alanwaketan): Reuse the upstream ones once Functionalization is done.
  torch::lazy::Value GetIrValue() const;
  void SetIrValue(torch::lazy::Value ir_value, bool inplace = true);
-  void SetInPlaceIrValue(torch::lazy::Value ir_value);
+  void SetInPlaceIrValue(torch::lazy::Value ir_value,
+                         bool delay_eager_executation = false);

  // TODO(alanwaketan): Reuse the upstream one once Functionalization is done.
  std::optional<at::Tensor> CurrentTensorData() const;
22 changes: 19 additions & 3 deletions torch_xla/csrc/tensor_methods.cpp
@@ -370,10 +370,26 @@ void all_reduce(const std::vector<XLATensorPtr>& inputs,
      reduce_type, input_values, GetAllReduceToken(inputs.front()->GetDevice()),
      scale, std::move(groups), pin_layout);
  for (size_t i = 0; i < inputs.size(); ++i) {
-    inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i));
+    // In eager mode we don't want to execute the IR for each tensor because
+    // that would run the `all_reduce` once per tensor.
+    inputs[i]->SetInPlaceIrValue(torch::lazy::Value(node, i),
+                                 /*delay_eager_executation=*/true);
  }

+  XLAGraphExecutor* graph_executor = XLAGraphExecutor::Get();
+  if (graph_executor->UseEagerMode()) {
+    // Execute the HLO that runs the `all_reduce` and in-place updates all
+    // tensors in one graph.
+    graph_executor->ApplyEagerSync(
+        const_cast<std::vector<XLATensorPtr>&>(inputs));
+  } else {
+    // The all_reduce token enforces the order of the cc ops. There is no
+    // point in setting it for eager mode since each cc op executes
+    // independently.
+    SetAllReduceToken(
+        inputs.front()->GetDevice(),
+        std::make_shared<torch::lazy::Value>(node, inputs.size()));
+  }
-  SetAllReduceToken(inputs.front()->GetDevice(),
-                    std::make_shared<torch::lazy::Value>(node, inputs.size()));
}

std::pair<XLATensorPtr, torch::lazy::Value> reduce_scatter(
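Taken together, the C++ changes mean the in-place `all_reduce` now runs as one eager graph instead of one graph per tensor, which is exactly what the `EagerOpExecuteTime` metric counts in the new test. A minimal sketch of inspecting that counter (an illustrative example, not from this PR; it assumes eager mode on a single XLA device):

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

torch_xla.experimental.eager_mode(True)
device = torch_xla.device()

t = torch.zeros(4, device=device)
t.add_(1)  # in eager mode each op executes its own small graph immediately
xm.wait_device_ops()
# metric_data returns (count, accumulator, samples); count is the number of
# eager graph executions observed so far.
print(met.metric_data("EagerOpExecuteTime")[0])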