-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support eager mode for multi-process training (#7327)
- Loading branch information
Showing
8 changed files
with
93 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import sys | ||
import os | ||
example_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0]))) | ||
sys.path.append(example_folder) | ||
from train_decoder_only_base import TrainDecoderOnlyBase | ||
|
||
import torch_xla.distributed.xla_multiprocessing as xmp | ||
import torch_xla.core.xla_model as xm | ||
|
||
|
||
class TrainDecoderXLADDP(TrainDecoderOnlyBase): | ||
|
||
def run_optimizer(self): | ||
# optimizer_step will call `optimizer.step()` and all_reduce the gradident | ||
xm.optimizer_step(self.optimizer) | ||
|
||
|
||
def _mp_fn(index): | ||
import torch_xla | ||
torch_xla.experimental.eager_mode(True) | ||
xla_ddp = TrainDecoderXLADDP() | ||
xla_ddp.start_training() | ||
|
||
|
||
if __name__ == '__main__': | ||
xmp.spawn(_mp_fn, args=()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
import torch_xla | ||
|
||
import torch_xla.core.xla_model as xm | ||
import torch_xla.debug | ||
import torch_xla.distributed.xla_multiprocessing as xmp | ||
import torch_xla.debug.metrics as met | ||
|
||
|
||
def _mp_fn(index): | ||
import torch_xla | ||
torch_xla.experimental.eager_mode(True) | ||
|
||
device = torch_xla.device() | ||
|
||
if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'): | ||
return | ||
|
||
ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device) | ||
ordinal_tensor_2 = torch.tensor([index], dtype=torch.int32).to(device) | ||
xm.wait_device_ops() | ||
met.clear_all() | ||
|
||
# all_reduce with list of tensor as input will be a inplace op. This is | ||
# used by the optimizer_step. | ||
xm.all_reduce(xm.REDUCE_SUM, [ordinal_tensor_1, ordinal_tensor_2]) | ||
|
||
xm.wait_device_ops() | ||
assert met.metric_data("EagerOpExecuteTime")[0] == 1 | ||
|
||
num_device = torch_xla.runtime.global_runtime_device_count() | ||
expected_sum = (num_device - 1) * num_device / 2 | ||
expected_1 = torch.tensor([(expected_sum)], dtype=torch.float) | ||
expected_2 = torch.tensor([(expected_sum)], dtype=torch.int32) | ||
assert torch.allclose(expected_1, ordinal_tensor_1.cpu()) | ||
assert torch.allclose(expected_2, ordinal_tensor_2.cpu()) | ||
|
||
|
||
if __name__ == '__main__': | ||
xmp.spawn(_mp_fn, args=()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters