Skip to content

Commit

Permalink
More flexible RNG synchronization (#8)
Browse files Browse the repository at this point in the history
* More flexible RNG synchronization

* Address review comment
  • Loading branch information
sgugger authored Mar 9, 2021
1 parent 38c4138 commit e67aa2e
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 36 deletions.
2 changes: 2 additions & 0 deletions docs/source/internal.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ Utilities

.. autofunction:: accelerate.utils.set_seed

.. autofunction:: accelerate.utils.synchronize_rng_state

.. autofunction:: accelerate.utils.synchronize_rng_states

.. autofunction:: accelerate.utils.wait_for_everyone
23 changes: 16 additions & 7 deletions docs/source/quicktour.rst
Original file line number Diff line number Diff line change
Expand Up @@ -340,20 +340,29 @@ library handles the sharding of your data between processes by changing that :ob

The :class:`~accelerate.data_loader.DataLoaderShard` subclasses :obj:`DataLoader` to add the following functionality:

- it synchronizes the torch random number generators of all processes at each new iteration, to ensure any
- it synchronizes the appropriate random number generator of all processes at each new iteration, to ensure any
randomization (like shuffling) is done the exact same way across processes.
- it puts the batches on the proper device before yielding them (unless you have opted out of
:obj:`device_placement=True`).

The random number generator synchronization will by default synchronize:

- the :obj:`generator` attribute of a given sampler (like the PyTorch :obj:`RandomSampler`) for PyTorch >= 1.6
- the main random number generator in PyTorch <=1.5.1

You can choose which random number generator(s) to synchronize with the :obj:`rng_types` argument of the main
:class:`~accelerate.Accelerator`. In PyTorch >= 1.6, it is recommended to rely on a local :obj:`generator` to avoid
setting the same seed in the main random number generator in all processes.
.. Warning::
The random number generator synchronization will affect any other potential random artifacts you could have in your
Synchronizing the main torch (or CUDA or XLA) random number generator will affect any other potential random artifacts you could have in your
dataset (like random data augmentation) in the sense that all processes will get the same random numbers from the torch
random modules (so will apply the same random data augmentation if it's controlled by torch). While this is usually
fine, you should use the random number generator from the Python :obj:`random` module or NumPy for your data
augmentation if you think this will be a problem.
random modules (so will apply the same random data augmentation if it's controlled by torch).
.. Note::
The randomization part of your sampler on the other hand should absolutely be done using the torch random number
generator (like in the traditional :obj:`RandomSampler`).
The randomization part of your custom sampler, batch sampler or iterable dataset should be done using a local
:obj:`torch.Generator` object (in PyTorch >= 1.6); see the traditional :obj:`RandomSampler` for an example.

See more details about the internals in the :doc:`Internals page <internal>`.
28 changes: 26 additions & 2 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch

from packaging import version

from .data_loader import prepare_data_loader
from .optimizer import AcceleratedOptimizer
from .state import AcceleratorState, DistributedType
from .utils import extract_model_from_parallel, gather, save, wait_for_everyone
from .utils import RNGType, extract_model_from_parallel, gather, save, wait_for_everyone


class Accelerator:
Expand All @@ -42,6 +44,18 @@ class Accelerator:
cpu (:obj:`bool`, `optional`):
Whether or not to force the script to execute on CPU. Will ignore any available GPU if set to :obj:`True` and
force the execution on one process only.
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
The list of random number generators to synchronize at the beginning of each iteration in your prepared
dataloaders. Should be one or several of:
- :obj:`"torch"`: the base torch random number generator
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
- :obj:`"xla"`: the XLA random number generator (TPU only)
- :obj:`"generator"`: the :obj:`torch.Generator` of the sampler (or batch sampler if there is no sampler
in your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
Will default to :obj:`["torch"]` for PyTorch versions <=1.5.1 and :obj:`["generator"]` for PyTorch versions
>= 1.6.
Attributes
Expand All @@ -50,7 +64,12 @@ class Accelerator:
"""

def __init__(
self, device_placement: bool = True, split_batches: bool = False, fp16: bool = None, cpu: bool = False
self,
device_placement: bool = True,
split_batches: bool = False,
fp16: bool = None,
cpu: bool = False,
rng_types: Optional[List[Union[str, RNGType]]] = None,
):
self.state = AcceleratorState(fp16=fp16, cpu=cpu, _from_accelerator=True)

Expand All @@ -67,6 +86,10 @@ def __init__(
# Internal references to the training objects
self._optimizers = []

# RNG Types
if rng_types is None:
self.rng_types = ["torch"] if version.parse(torch.__version__) <= version.parse("1.5.1") else ["generator"]

@property
def distributed_type(self):
return self.state.distributed_type
Expand Down Expand Up @@ -187,6 +210,7 @@ def prepare_data_loader(self, data_loader):
process_index=self.process_index,
split_batches=self.split_batches,
put_on_device=self.device_placement,
rng_types=self.rng_types,
)

def prepare_optimizer(self, optimizer):
Expand Down
47 changes: 43 additions & 4 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import List, Optional, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset

from packaging import version

from .state import AcceleratorState, DistributedType, is_tpu_available
from .utils import send_to_device, synchronize_rng_states
from .utils import RNGType, send_to_device, synchronize_rng_states


if is_tpu_available():
Expand Down Expand Up @@ -262,16 +262,29 @@ class DataLoaderShard(DataLoader):
The dataset to use to build this dataloader.
device (:obj:`torch.device`, `optional`):
If passed, the device to put all batches on.
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
The list of random number generators to synchronize at the beginning of each iteration. Should be one or
several of:
- :obj:`"torch"`: the base torch random number generator
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
- :obj:`"xla"`: the XLA random number generator (TPU only)
- :obj:`"generator"`: an optional :obj:`torch.Generator`
generator (:obj:`torch.Generator`, `optional`):
A random number generator to keep synchronized across processes.
kwargs:
All other keyword arguments to pass to the regular :obj:`DataLoader` initialization.
"""

def __init__(self, dataset, device=None, **kwargs):
def __init__(self, dataset, device=None, rng_types=None, generator=None, **kwargs):
super().__init__(dataset, **kwargs)
self.device = device
self.rng_types = rng_types
self.generator = generator

def __iter__(self):
synchronize_rng_states()
if self.rng_types is not None:
synchronize_rng_states(self.rng_types, self.generator)
state = AcceleratorState()
for batch in super().__iter__():
if state.distributed_type == DistributedType.TPU:
Expand All @@ -286,6 +299,7 @@ def prepare_data_loader(
process_index: Optional[int] = None,
split_batches: bool = False,
put_on_device: bool = False,
rng_types: Optional[List[Union[str, RNGType]]] = None,
) -> DataLoader:
"""
Wraps a PyTorch :obj:`DataLoader` to generate batches for one of the processes only.
Expand Down Expand Up @@ -318,6 +332,15 @@ def prepare_data_loader(
put_on_device (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to put the batches on :obj:`device` (only works if the batches are nested lists, tuples or
dictionaries of tensors).
rng_types (list of :obj:`str` or :class:`~accelerate.utils.RNGType`):
The list of random number generators to synchronize at the beginning of each iteration. Should be one or
several of:
- :obj:`"torch"`: the base torch random number generator
- :obj:`"cuda"`: the CUDA random number generator (GPU only)
- :obj:`"xla"`: the XLA random number generator (TPU only)
- :obj:`"generator"`: the :obj:`torch.Generator` of the sampler (or batch sampler if there is no sampler
in your dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
Returns:
:obj:`torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches
Expand All @@ -342,9 +365,12 @@ def prepare_data_loader(

new_dataset = dataloader.dataset
new_batch_sampler = dataloader.batch_sampler
generator = getattr(dataloader, "generator", None)
# No change if no multiprocess
if num_processes != 1:
if isinstance(new_dataset, IterableDataset):
if getattr(dataloader.dataset, "generator", None) is not None:
generator = dataloader.dataset.generator
new_dataset = IterableDatasetShard(
new_dataset,
batch_size=dataloader.batch_size,
Expand All @@ -355,6 +381,13 @@ def prepare_data_loader(
)
else:
# New batch sampler for the current process.
if hasattr(dataloader.sampler, "generator"):
if dataloader.sampler.generator is None:
dataloader.sampler.generator = torch.Generator()
generator = dataloader.sampler.generator
generator.manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
elif getattr(dataloader.batch_sampler, "generator", None) is not None:
generator = dataloader.batch_sampler.generator
new_batch_sampler = BatchSamplerShard(
dataloader.batch_sampler,
num_processes=num_processes,
Expand All @@ -369,8 +402,12 @@ def prepare_data_loader(
"sampler",
"batch_sampler",
"drop_last",
"generator",
]

if rng_types is not None and generator is None and "generator" in rng_types:
rng_types.remove("generator")

kwargs = {
k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
for k in _PYTORCH_DATALOADER_KWARGS
Expand All @@ -380,5 +417,7 @@ def prepare_data_loader(
new_dataset,
device=device if put_on_device else None,
batch_sampler=new_batch_sampler,
rng_types=rng_types,
generator=generator,
**kwargs,
)
33 changes: 23 additions & 10 deletions src/accelerate/test_utils/test_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from packaging import version

import torch
from torch.utils.data import DataLoader

Expand All @@ -34,10 +36,16 @@ def init_state_check():

def rng_sync_check():
state = AcceleratorState()
synchronize_rng_states()
synchronize_rng_states(["torch"])
assert are_the_same_tensors(torch.get_rng_state())
if state.distributed_type == DistributedType.MULTI_GPU:
synchronize_rng_states(["cuda"])
assert are_the_same_tensors(torch.cuda.get_rng_state())
if version.parse(torch.__version__) >= version.parse("1.6.0"):
generator = torch.Generator()
synchronize_rng_states(["generator"], generator=generator)
assert are_the_same_tensors(generator.get_state())

if state.local_process_index == 0:
print("All rng are properly synched.")

Expand Down Expand Up @@ -101,13 +109,14 @@ def dl_preparation_check():
print("Shuffled dataloader passing.")


def mock_training(length, batch_size):
def mock_training(length, batch_size, generator):
set_seed(42)
generator.manual_seed(42)
train_set = RegressionDataset(length=length)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):
for epoch in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
Expand All @@ -119,21 +128,23 @@ def mock_training(length, batch_size):

def training_check():
state = AcceleratorState()
generator = torch.Generator()
batch_size = 8
length = batch_size * 4 * state.num_processes

train_set, old_model = mock_training(length, batch_size * state.num_processes)
train_set, old_model = mock_training(length, batch_size * state.num_processes, generator)
assert are_the_same_tensors(old_model.a)
assert are_the_same_tensors(old_model.b)

accelerator = Accelerator()
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
for _ in range(3):
generator.manual_seed(42)
for epoch in range(3):
for batch in train_dl:
model.zero_grad()
output = model(batch["x"])
Expand All @@ -145,15 +156,16 @@ def training_check():
assert torch.allclose(old_model.a, model.a)
assert torch.allclose(old_model.b, model.b)

accelerator.print("Training yielded the same results on one CPU or distributes setup with no batch split.")
accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.")

accelerator = Accelerator(split_batches=True)
train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True)
train_dl = DataLoader(train_set, batch_size=batch_size * state.num_processes, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
Expand All @@ -170,12 +182,13 @@ def training_check():

# Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16
accelerator = Accelerator(fp16=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True)
train_dl = DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator)
model = RegressionModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer)
set_seed(42)
generator.manual_seed(42)
for _ in range(3):
for batch in train_dl:
model.zero_grad()
Expand Down
Loading

0 comments on commit e67aa2e

Please sign in to comment.