From 679796434a294f039de5e91603b078d07cf0feb3 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Tue, 9 Feb 2021 12:54:38 -0500
Subject: [PATCH] Leave buffers on self.compute_device (#67)

---
 .../shard_params_data_parallel.py      | 27 ++++++++++++-----
 .../test_shard_params_data_parallel.py | 30 ++++++++++++-------
 2 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/fairscale/nn/data_parallel/shard_params_data_parallel.py b/fairscale/nn/data_parallel/shard_params_data_parallel.py
index 9676db895..d73e1399d 100644
--- a/fairscale/nn/data_parallel/shard_params_data_parallel.py
+++ b/fairscale/nn/data_parallel/shard_params_data_parallel.py
@@ -125,16 +125,18 @@ def __init__(
         # Shard module parameters in place
         self._shard_parameters_()
 
-        if self.mixed_precision:
-            # Cast all module buffers to FP16 (buffers are not sharded).
-            self.apply(cast_buffers_to_fp16)
-
         # Make sure all parameters are sharded.
         for n, p in self.named_parameters():
             assert getattr(p, "_is_sharded", False), f"found unsharded parameter: {n} ; {p.size()}"
 
         self._reset_lazy_init()
 
+    @torch.no_grad()
+    def _all_buffers_to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        """Move all buffers to the specified device and dtype, recursively."""
+        cast_fn = functools.partial(cast_buffers_, device=device, dtype=dtype)
+        self.apply(cast_fn)
+
     @torch.no_grad()
     def _shard_parameters_(self) -> None:
         """
@@ -217,17 +219,19 @@ def state_dict(self, *args, **kwargs):  # type: ignore
         """
         Returns the whole (unsharded) state of the module. Parameters are not
         sharded, so the resulting state_dict can be loaded directly by the
-        wrapped Module without any sharding-specific logic.
+        wrapped Module without any sharding-specific logic. Returned tensors will always be typed float32.
         """
         torch.cuda.synchronize()
         self._lazy_init()
         self._rebuild_full_params()
+        self._all_buffers_to(dtype=torch.float32)  # Buffer dtype stays consistent with parameters.
         state_dict = self.module.state_dict(*args, **kwargs)
         # We don't free the params after generating the state dict, since
         # freeing is done in-place (via the Storage) and would corrupt the
         # returned state dict. However, we need to maintain the invariant that
         # p.data corresponds to the FP32 param shard, so we do that here.
         self._use_fp32_param_shard()
+        self._all_buffers_to(dtype=self.compute_dtype)
         return state_dict
 
     # TODO (Min): figuring out how to do typing for this overloaded function.
@@ -278,6 +282,10 @@ def _lazy_init(self) -> None:
         if self._is_root is None:
             self._set_is_root()
             self._setup_streams()
+            if self.cpu_offload:  # Buffers stay on GPU and don't get sharded.
+                self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
+            else:
+                self._all_buffers_to(dtype=self.compute_dtype)
 
         # Don't free the full params for the outer-most (root) instance, since
         # those params will be needed immediately after for the backward pass.
@@ -652,11 +660,14 @@ def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     return args, kwargs
 
 
-def cast_buffers_to_fp16(module: nn.Module) -> None:
-    """Cast buffers of a module to FP16."""
+def cast_buffers_(
+    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+) -> None:
+    """Cast all of module.named_buffers to the given device and dtype."""
+    # If buffers are already on the right device and/or dtype, this is just Python loop cost.
     for key, buf in module.named_buffers(recurse=False):
         if buf is not None:
-            setattr(module, key, buf.half())
+            setattr(module, key, buf.to(dtype=dtype, device=device))
 
 
 def free_storage_(data: torch.Tensor) -> None:
diff --git a/tests/nn/data_parallel/test_shard_params_data_parallel.py b/tests/nn/data_parallel/test_shard_params_data_parallel.py
index 88d78f0de..94956f7f7 100644
--- a/tests/nn/data_parallel/test_shard_params_data_parallel.py
+++ b/tests/nn/data_parallel/test_shard_params_data_parallel.py
@@ -27,6 +27,8 @@
 
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 
+_BUFFER_NAME = "vocab_bias"
+
 
 class DistributedTest(unittest.TestCase):
     def setUp(self):
@@ -42,9 +44,9 @@ def setUp(self):
             raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
 
     @staticmethod
-    def _train_for_several_steps(model, num_steps, autocast):
+    def _train_for_several_steps(model, num_steps, autocast, lr=0.01):
         model_device = next(model.parameters()).device
-        optim = torch.optim.Adam(model.parameters(), lr=0.01)
+        optim = torch.optim.Adam(model.parameters(), lr=lr)
         for _ in range(num_steps):
             optim.zero_grad()
             with torch.cuda.amp.autocast(enabled=autocast):
@@ -178,7 +180,11 @@ def test_cpu_offload_and_cpu_grads(self):
         # We don't test the False condition because that requires the optimizer to internally do
         # the device transfer and PyTorch optimizers don't support this.
         config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
-        test_fn = functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False)
+        test_fn = functools.partial(
+            self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.001
+        )
+        # We use a lower lr to reduce this test's sensitivity to slightly different CPU vs CUDA behavior of PyTorch.
+        # With lr=0.01, it fails on torch 1.6.0.
         spawn_and_init(test_fn)
 
     def test_cpu_offload_and_cuda_grads_breaks(self):
@@ -210,7 +216,7 @@ def test_delayed_reduce_scatter(self):
         spawn_and_init(test_fn)
 
     @classmethod
-    def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True):
+    def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3, use_cuda=True, lr=0.01):
         if config["mixed_precision"]:
             autocast = True
             # Force the compute dtype to be torch.float32 so that we get
@@ -224,7 +230,7 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
         # Establish reference behavior with PyTorch DDP (+ optionally autocast).
         model = model_init_fn(group=group, wrapper_config=None).cuda()
         model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, process_group=group)
-        ref_loss = cls._train_for_several_steps(model, num_steps, autocast)
+        ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr)
         ref_state_dict = model.module.state_dict()
 
         # Confirm we get the same behavior using ShardParamsDataParallel.
@@ -233,14 +239,14 @@ def _test_identical_outputs(cls, model_init_fn, config, rank, group, num_steps=3
             model = model.cuda()
         else:
             assert next(model.parameters()).device == torch.device("cpu")
-        shard_loss = cls._train_for_several_steps(model, num_steps, autocast)
+        shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr)
         shard_state_dict = model.state_dict()
 
         try:
             torch.testing.assert_allclose(ref_loss, shard_loss)
             assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
         except (AssertionError, RuntimeError) as e:
-            raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}" "\n\n{e}")
+            raise Exception(f"ShardParamsDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
 
 
 class TestParamInit(DistributedTest):
@@ -332,7 +338,7 @@ def test_local_state_dict_odd_vocab_shape_breaks(self):
         spawn_and_init(test_fn)
 
     @classmethod
-    def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
+    def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=16):
         """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
         model = ShardParamsDataParallel(
             TransformerWithSharedParams(d_model=d_model, d_vocab=d_vocab), group, **config
@@ -346,11 +352,11 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
         state_1_weight = state_1[weight_key]
         assert state_1_weight.dtype == torch.float32, f"got dtype {state_1_weight.dtype} expected torch.float32"
         if not model.flatten_parameters:
-            # This weight will be sharded since we access module.state_dict directly
+            # The weight will be sharded since we access module.state_dict directly
            state_1_module_weight = model.module.state_dict()[weight_key]
             torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
             torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
-        self._train_for_several_steps(model, 4, model.mixed_precision)
+        self._train_for_several_steps(model, 1, model.mixed_precision)
         state_2 = model.local_state_dict()
         state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
 
@@ -361,7 +367,7 @@ def _load_local_and_train(self, config, rank, group, d_model=32, d_vocab=32):
         # Assert that parameters were updated since before training
         unchanged = []
         for k in state_1:
-            if (state_before_training[k] == state_after_training[k]).all():
+            if (state_before_training[k] == state_after_training[k]).all() and (_BUFFER_NAME not in k):
                 unchanged.append(k)
         if unchanged:
             raise AssertionError(f"params {unchanged} not changed after training")
@@ -520,6 +526,7 @@ def __init__(self, *unused_args, d_vocab=32, d_model=16, **unused_kwargs):
         self.output_proj = nn.Linear(d_model, d_vocab)
         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight
+        self.register_buffer(_BUFFER_NAME, self.embed_tokens.weight.new_ones((d_model,)))
 
     def get_input(self, device):
         torch.manual_seed(1)  # keep everything deterministic
@@ -529,6 +536,7 @@ def get_input(self, device):
 
     def forward(self, src_ids, tgt_ids):
         src = self.embed_tokens(src_ids)
+        src = src + self.vocab_bias
         tgt = self.embed_tokens(tgt_ids)
         x = self.transformer(src, tgt)
         return self.output_proj(x)
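
A minimal standalone sketch (not part of the patch) of the buffer-casting behavior introduced above: cast_buffers_ mirrors the helper added in shard_params_data_parallel.py, while ToyModule is a hypothetical stand-in for the wrapped module (the patch's tests use TransformerWithSharedParams). Buffers are rewritten in place to the target device/dtype, parameters are left untouched.

from typing import Optional

import torch
from torch import nn


def cast_buffers_(
    module: nn.Module, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
) -> None:
    # Rewrite each buffer of this module (non-recursively) on the target device/dtype.
    for key, buf in module.named_buffers(recurse=False):
        if buf is not None:
            setattr(module, key, buf.to(dtype=dtype, device=device))


class ToyModule(nn.Module):
    # Hypothetical module with one parameter-holding layer and one registered buffer.
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("vocab_bias", torch.ones(4))


if __name__ == "__main__":
    m = ToyModule()
    # Apply recursively, analogous to _all_buffers_to calling self.apply(cast_fn).
    m.apply(lambda mod: cast_buffers_(mod, dtype=torch.float16))
    assert m.vocab_bias.dtype == torch.float16     # buffer was cast
    assert m.linear.weight.dtype == torch.float32  # parameters are untouched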