Commit e7f05a6
Merge branch 'master' into bugfix/SWA_scheduler_step
s-rog authored Apr 6, 2021
2 parents 1581527 + 7f91c5e commit e7f05a6
Showing 11 changed files with 51 additions and 51 deletions.
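For orientation, the bugfix/SWA_scheduler_step branch concerns how Stochastic Weight Averaging steps its scheduler. A minimal sketch of enabling SWA through the built-in callback (the parameter values are illustrative, not taken from this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # Sketch: start averaging at 75% of training with an illustrative SWA learning rate.
    # The fix merged here relates to the SWA scheduler being stepped on an epoch interval.
    swa = StochasticWeightAveraging(swa_epoch_start=0.75, swa_lrs=0.05)
    trainer = Trainer(max_epochs=20, callbacks=[swa])
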
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -170,6 +170,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Sanitize `None` params during pruning ([#6836](https://github.com/PyTorchLightning/pytorch-lightning/pull/6836))


- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))


@@ -200,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))


- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))


## [1.2.6] - 2021-03-30

### Changed
3 changes: 1 addition & 2 deletions docs/source/advanced/tpu.rst
@@ -64,8 +64,7 @@ To get a TPU on colab, follow these steps:

.. code-block::
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
5. Once the above is done, install PyTorch Lightning (v 0.7.0+).
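That last step is typically a single cell in the same notebook (a minimal sketch; pin the version you need):

    !pip install pytorch-lightning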

24 changes: 0 additions & 24 deletions docs/source/common/lightning_module.rst
@@ -912,30 +912,6 @@ use_amp
~~~~~~~
True if using Automatic Mixed Precision (AMP)

------------

use_ddp
~~~~~~~
True if using ddp

------------

use_ddp2
~~~~~~~~
True if using ddp2

------------

use_dp
~~~~~~
True if using dp

------------

use_tpu
~~~~~~~
True if using TPUs

--------------

automatic_optimization
4 changes: 1 addition & 3 deletions docs/source/starter/introduction_guide.rst
@@ -572,9 +572,7 @@ Next, install the required xla library (adds support for PyTorch on TPUs)

.. code-block:: shell
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
In distributed training (multiple GPUs and multiple TPU cores) each GPU or TPU core will run a copy
of this program. This means that without taking any care you will download the dataset N times which
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -480,7 +480,7 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
)
self.setup_precision_plugin(plugin)

def save_checkpoint(self, checkpoint: Dict[str, Any], filepath) -> None:
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
"""Save model/training states as a checkpoint file through state-dump and file-write.
Args:
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -77,7 +77,7 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
# When `current_epoch` is 10, feature_extractor will start training.
if current_epoch == self._unfreeze_at_epoch:
self.unfreeze_and_add_param_group(
module=pl_module.feature_extractor,
modules=pl_module.feature_extractor,
optimizer=optimizer,
train_bn=True,
)
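For context, the keyword being corrected above belongs to BaseFinetuning.unfreeze_and_add_param_group, whose parameter is named `modules`. A minimal sketch of a finetuning callback built around it (the `feature_extractor` attribute and the epoch threshold are illustrative):

    from pytorch_lightning.callbacks import BaseFinetuning

    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch: int = 10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # Freeze the backbone before training starts.
            self.freeze(pl_module.feature_extractor)

        def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
            # Once the threshold epoch is reached, unfreeze the backbone and hand its
            # parameters to the optimizer as a new parameter group.
            if current_epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.feature_extractor,
                    optimizer=optimizer,
                    train_bn=True,
                )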
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/pruning.py
@@ -422,7 +422,9 @@ def sanitize_parameters_to_prune(
current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)]

if parameters_to_prune is None:
parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)]
parameters_to_prune = [
(m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
]
elif (
isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0
and all(len(p) == 2 for p in parameters_to_prune)
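The switch from `hasattr` to a `getattr(..., None) is not None` check matters because a module can expose a parameter attribute whose value is None (for example a Linear created with bias=False). A minimal sketch of the difference (plain PyTorch, illustrative names):

    import torch.nn as nn

    layer = nn.Linear(32, 32, bias=False)

    # The attribute exists on the module even though its value is None ...
    print(hasattr(layer, "bias"))                    # True
    # ... so the old check would keep a None parameter in parameters_to_prune,
    # while the getattr-based check filters it out.
    print(getattr(layer, "bias", None) is not None)  # False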
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -13,7 +13,7 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, TYPE_CHECKING, TypeVar, Union

import torch
from torch.nn import Module
@@ -30,6 +30,8 @@
if TYPE_CHECKING:
from pytorch_lightning.trainer.trainer import Trainer

TBroadcast = TypeVar("T")


class TrainingTypePlugin(Plugin, ABC):
"""A Plugin to change the behaviour of the training, validation and test-loop."""
@@ -88,7 +90,7 @@ def barrier(self, name: Optional[str] = None) -> None:
"""Forces all possibly joined processes to wait for each other"""

@abstractmethod
def broadcast(self, obj: object, src: int = 0) -> object:
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
"""Broadcasts an object to all processes"""

@abstractmethod
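A minimal sketch of why the TypeVar-based annotation helps: the return type of broadcast now mirrors the argument type for static checkers instead of collapsing to `object` (single-process stand-in, not the actual distributed implementation):

    from typing import TypeVar

    T = TypeVar("T")

    def broadcast(obj: T, src: int = 0) -> T:
        # Stand-in: a real plugin would send `obj` from rank `src` to every process.
        # The generic signature tells type checkers that a dict in means a dict out.
        return obj

    hparams = broadcast({"lr": 0.01})  # inferred as a dict, not as plain `object`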
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -866,7 +866,7 @@ def validate(
self.validating = True

# If you supply a datamodule you can't supply val_dataloaders
if val_dataloaders and datamodule:
if val_dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
)
@@ -928,7 +928,7 @@ def test(
self.testing = True

# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
if test_dataloaders is not None and datamodule:
raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')

model_provided = model is not None
@@ -1024,7 +1024,7 @@ def predict(
self.state = TrainerState.PREDICTING
self.predicting = True

if dataloaders and datamodule:
if dataloaders is not None and datamodule:
raise MisconfigurationException(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)
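The `is not None` changes above guard against Python truthiness: an explicitly passed but empty list of dataloaders is falsy, so the old check would silently skip the misconfiguration error. A minimal sketch of the pitfall (illustrative values):

    val_dataloaders = []   # user explicitly passed an (empty) list
    datamodule = object()  # stand-in for a LightningDataModule instance

    # Old check: the empty list is falsy, so no error is raised even though
    # both arguments were supplied.
    if val_dataloaders and datamodule:
        print("old check: would raise MisconfigurationException")

    # New check: comparing against None catches the conflict regardless of
    # whether the list happens to be empty.
    if val_dataloaders is not None and datamodule:
        print("new check: would raise MisconfigurationException")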
19 changes: 11 additions & 8 deletions tests/callbacks/test_pruning.py
@@ -36,7 +36,7 @@ def __init__(self):
self.layer = Sequential(
OrderedDict([
("mlp_1", nn.Linear(32, 32)),
("mlp_2", nn.Linear(32, 32)),
("mlp_2", nn.Linear(32, 32, bias=False)),
("mlp_3", nn.Linear(32, 2)),
])
)
@@ -85,7 +85,10 @@ def train_with_pruning_callback(
if parameters_to_prune:
pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")]
else:
pruning_kwargs["parameter_names"] = ["weight"]
if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
pruning_kwargs["parameter_names"] = ["weight"]
else:
pruning_kwargs["parameter_names"] = ["weight", "bias"]
if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"):
pruning_kwargs["pruning_dim"] = 0
if pruning_fn == "ln_structured":
@@ -249,14 +252,14 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 500 (48.83%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 44 (68.75%)", # noqa: E501
"Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)",
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 506 (49.41%) -> 633 (61.82%)", # noqa: E501
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 38 (59.38%) -> 47 (73.44%)", # noqa: E501
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 500 (48.83%) -> 635 (62.01%)", # noqa: E501
"Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 44 (68.75%) -> 45 (70.31%)", # noqa: E501
"Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)",
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 635 (62.01%) -> 830 (81.05%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 45 (70.31%) -> 54 (84.38%)", # noqa: E501
]

filepath = str(tmpdir / "foo.ckpt")
26 changes: 20 additions & 6 deletions tests/trainer/test_dataloaders.py
@@ -636,28 +636,42 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):

def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
model = BoringModel()
original_dataset = model.train_dataloader().dataset

class IterableWithLen(IterableDataset):
class IterableWithoutLen(IterableDataset):

def __iter__(self):
return iter(original_dataset)

class IterableWithLen(IterableWithoutLen):

def __len__(self):
return len(original_dataset)

# with __len__ defined
dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,
)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.validate(model, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.predict(model, dataloaders=[dataloader])

# without __len__ defined
dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
assert not has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.validate(model, val_dataloaders=dataloader)
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
trainer.test(model, test_dataloaders=dataloader)
trainer.predict(model, dataloaders=dataloader)
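
Background for the warning being tested: an IterableDataset is driven entirely by __iter__, so a __len__ it defines is only advisory and may not match what the loader actually yields. A minimal, Lightning-free sketch (class and sizes are illustrative):

    import torch
    from torch.utils.data import DataLoader, IterableDataset

    class Stream(IterableDataset):
        def __iter__(self):
            # What is actually yielded is defined here, not by __len__.
            return iter(torch.arange(10).unbind(0))

        def __len__(self):
            # Advisory only; nothing forces it to agree with __iter__.
            return 10

    loader = DataLoader(Stream(), batch_size=4)
    print(len(loader))             # 3, derived from the declared __len__
    print(sum(1 for _ in loader))  # also 3 here, but can differ if __len__ is wrong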


@RunIf(min_gpus=2)
