
[bugfix] Perform reduction for dict in training_step and DP #6324

Merged: 14 commits, Mar 4, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))


- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))


## [1.2.1] - 2021-02-23

### Fixed
10 changes: 8 additions & 2 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -19,6 +19,7 @@
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class DataParallelPlugin(ParallelPlugin):
@@ -46,8 +47,13 @@ def reduce(self, tensor, *args, **kwargs):
        if isinstance(tensor, Result):
            tensor.dp_reduce()

-       elif isinstance(tensor, torch.Tensor):
-           tensor = tensor.mean()
+       else:
+
+           def _reduce(tensor: torch.Tensor):
+               dtype_tensor = tensor.dtype
+               return tensor.float().mean().type(dtype_tensor)
+
+           tensor = apply_to_collection(tensor, torch.Tensor, _reduce)

        return tensor

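To make the new branch concrete, here is a small standalone sketch (not part of the diff) of what `apply_to_collection` does with a dict of per-GPU outputs; the example dict and its keys are made up for illustration:

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection


def _reduce(tensor: torch.Tensor) -> torch.Tensor:
    # Cast to float for the mean, then restore the original dtype
    # so integer entries (e.g. counts) stay integers.
    dtype_tensor = tensor.dtype
    return tensor.float().mean().type(dtype_tensor)


# Hypothetical outputs gathered by DataParallel: one entry per device.
outputs = {
    "loss": torch.tensor([0.5, 1.5]),
    "num_correct": torch.tensor([3, 5]),
}

reduced = apply_to_collection(outputs, torch.Tensor, _reduce)
# reduced == {"loss": tensor(1.), "num_correct": tensor(4)}
```

Unlike the old `elif isinstance(tensor, torch.Tensor)` branch, this walks every tensor inside a dict (or other collection) rather than only handling a bare tensor.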
19 changes: 19 additions & 0 deletions tests/accelerators/test_dp.py
@@ -123,3 +123,22 @@ def test_dp_test(tmpdir):
    new_weights = model.layer_0.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))


@RunIf(min_gpus=2)
def test_dp_training_step_dict(tmpdir):
"""
This test verify dp properly reduce dictionaries
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test does not do what it says. It just makes sure that nothing crashes, but there are no assertions that the behavior is correct.
I bet if we change mean to something else, this test will not fail.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

besides, nothing gets reduced here anyway because BoringModel uses batch size 1 for all dataloaders and so this never runs on more than 1 gpu

"""

    model = BoringModel()
    model.training_step_end = None
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=0,
        gpus=2,
        accelerator='dp',
    )
    trainer.fit(model)
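For context, a minimal sketch of the kind of module this change targets: a `training_step` that returns a plain dict. The `DictReturnModel` name and the extra `log_value` key are hypothetical, and the `layer`/`loss` helpers are assumed to exist on `BoringModel` from `tests.helpers`:

```python
import torch
import pytorch_lightning as pl
from tests.helpers import BoringModel


class DictReturnModel(BoringModel):
    """Hypothetical model whose training_step returns a dict instead of a bare tensor."""

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        # With accelerator='dp', every tensor in this dict is averaged across GPUs
        # by DataParallelPlugin.reduce (via apply_to_collection) before backward.
        return {"loss": loss, "log_value": loss.detach()}


model = DictReturnModel()
model.training_step_end = None  # no manual reduction; let the plugin reduce the dict
trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, gpus=2, accelerator='dp')
trainer.fit(model)
```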
4 changes: 2 additions & 2 deletions tests/callbacks/test_pruning.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
from collections import OrderedDict
from logging import INFO

@@ -22,7 +21,7 @@
from torch.nn import Sequential

from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -274,6 +273,7 @@ def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
    seed_everything(0)

    class TestPruning(ModelPruning):

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]