From 48327b7e610dab66fe23df7f1ac7dd2110fc0dc9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 4 Feb 2021 14:33:39 +0100 Subject: [PATCH 01/27] Remove pruning check because it was added in 1.4.0 and that is our minimal torch version --- pytorch_lightning/callbacks/pruning.py | 6 +----- pytorch_lightning/utilities/__init__.py | 1 - pytorch_lightning/utilities/imports.py | 1 - tests/callbacks/test_pruning.py | 12 ++---------- 4 files changed, 3 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index c008296d82fba..3bb6e19ab62a1 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -22,19 +22,15 @@ from functools import partial from typing import Callable, List, Optional, Tuple, Union +import torch.nn.utils.prune as pytorch_prune from torch import nn from torch.nn.modules.container import ModuleDict, ModuleList import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _PYTORCH_PRUNE_AVAILABLE: - import torch.nn.utils.prune as pytorch_prune - - _PYTORCH_PRUNING_FUNCTIONS = { "ln_structured": pytorch_prune.ln_structured, "l1_unstructured": pytorch_prune.l1_unstructured, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index a8f3e134936ff..bf6069230f115 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -35,7 +35,6 @@ _module_available, _NATIVE_AMP_AVAILABLE, _OMEGACONF_AVAILABLE, - _PYTORCH_PRUNE_AVAILABLE, _RPC_AVAILABLE, _TORCHTEXT_AVAILABLE, _XLA_AVAILABLE, diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index c0a4d15411dc4..7a65d32cb3ff1 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -54,4 +54,3 @@ def _module_available(module_path: str) -> bool: _GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group') _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0") _BOLTS_AVAILABLE = _module_available('pl_bolts') -_PYTORCH_PRUNE_AVAILABLE = _module_available('torch.nn.utils.prune') diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 37c8fb464714f..46bf71f517fae 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -17,18 +17,14 @@ import numpy as np import pytest import torch +import torch.nn.utils.prune as pytorch_prune from torch import nn from pytorch_lightning import Trainer -from pytorch_lightning.utilities import _PYTORCH_PRUNE_AVAILABLE +from pytorch_lightning.callbacks import ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel -if _PYTORCH_PRUNE_AVAILABLE: - import torch.nn.utils.prune as pytorch_prune - - from pytorch_lightning.callbacks import ModelPruning - class PruningModel(BoringModel): @@ -198,7 +194,6 @@ def test_with_pruning_callback_misconfiguration(tmpdir): _ = ModelPruning(**model_pruning_args) -@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch prung is needed for this test. 
") @pytest.mark.parametrize("parameters_to_prune", [False, True]) @pytest.mark.parametrize("use_global_unstructured", [False, True]) @pytest.mark.parametrize("use_custom_pruning_fn", [False, True]) @@ -208,7 +203,6 @@ def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, accelerator=None, gpus=None, num_processes=1, use_custom_pruning_fn=use_custom_pruning_fn) -@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch prung is needed for this test. ") @pytest.mark.parametrize("parameters_to_prune", [False, True]) @pytest.mark.parametrize("use_global_unstructured", [False, True]) @pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', @@ -219,14 +213,12 @@ def test_pruning_callback_ddp(tmpdir, use_global_unstructured, parameters_to_pru accelerator="ddp", gpus=2, num_processes=0) -@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch prung is needed for this test. ") @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_spawn(tmpdir): train_with_pruning_callback(tmpdir, False, True, accelerator="ddp_spawn", gpus=2, num_processes=None) -@pytest.mark.skipif(not _PYTORCH_PRUNE_AVAILABLE, reason="PyTorch prung is needed for this test. ") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_cpu(tmpdir): train_with_pruning_callback(tmpdir, True, False, accelerator="ddp_cpu", gpus=None, num_processes=2) From 91a111f0590b8f7e4ab1277f5da579d63003f275 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 04:38:01 +0100 Subject: [PATCH 02/27] Fixing many bugs --- pytorch_lightning/callbacks/pruning.py | 187 ++++++++------------ tests/callbacks/test_pruning.py | 229 +++++++++---------------- 2 files changed, 153 insertions(+), 263 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 3bb6e19ab62a1..f95fab3b53071 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -26,7 +26,6 @@ from torch import nn from torch.nn.modules.container import ModuleDict, ModuleList -import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -47,18 +46,17 @@ class ModelPruning(Callback): - PARAMETER_NAMES = ("weight", "bias") def __init__( self, - pruning_fn: Union[Callable, str] = None, + pruning_fn: Optional[Union[Callable, str]] = None, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None, - parameter_names: List[str] = ["weight"], + parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, - amount: Optional[Union[int, float]] = 0.5, - make_pruning_permanent: Optional[bool] = True, - use_lottery_ticket_hypothesis: Optional[bool] = True, + amount: Union[int, float, Callable] = 0.5, + make_pruning_permanent: bool = True, + use_lottery_ticket_hypothesis: bool = True, pruning_dim: Optional[int] = None, pruning_norm: Optional[int] = None, ) -> None: @@ -110,15 +108,14 @@ def __init__( amount: quantity of parameters to prune: - - float, should be between 0.0 and 1.0 and represent the fraction of parameters to prune. - - int, it represents the absolute number of parameters to prune. 
-            - Callable, the function will be called on every epoch.
+            - float, between 0.0 and 1.0. represents the fraction of parameters to prune.
+            - int, represents the absolute number of parameters to prune.
+            - Callable, for dynamic values. will be called every epoch.
 
-        make_pruning_permanent: if True then all
-            reparametrization pre-hooks and tensors with mask
-            will be removed on fit end.
+        make_pruning_permanent: if True then all reparametrization pre-hooks and tensors
+            with mask will be removed on fit end.
 
-        use_lottery_ticket_hypothesis: Wether to use algorithm describes in
+        use_lottery_ticket_hypothesis: Whether to use algorithm describes in
             "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf)
 
         pruning_dim: if you are using structured pruning method you need
@@ -138,17 +135,18 @@ def __init__(
         for param_name in self._parameter_names:
             if param_name not in self.PARAMETER_NAMES:
                 raise MisconfigurationException(
-                    f"The provided parameter_names {param_name} isn't in {self.PARAMETER_NAMES} "
+                    f"The provided `parameter_names`: {param_name} isn't in {self.PARAMETER_NAMES}"
                 )
 
         if isinstance(pruning_fn, str):
+            pruning_kwargs = {}
             pruning_fn = pruning_fn.lower()
             if pruning_fn not in _PYTORCH_PRUNING_FUNCTIONS:
                 raise MisconfigurationException(
-                    f"The provided pruning_fn {pruning_fn} isn't available with "
-                    f"PyTorch build-in {_PYTORCH_PRUNING_FUNCTIONS.keys()} "
+                    f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's"
+                    f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} "
                 )
-            if "unstructured" not in pruning_fn:
+            if pruning_fn.endswith("_structured"):
                 if pruning_dim is None:
                     raise MisconfigurationException(
                         "When requesting `structured` pruning, the `pruning_dim` should be provided."
                     )
@@ -158,33 +156,30 @@ def __init__(
                     raise MisconfigurationException(
                         "When requesting `ln_structured` pruning, the `pruning_norm` should be provided."
                     )
-
-                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim, n=pruning_norm)
-                else:
-                    pruning_fn = self._create_pruning_fn(pruning_fn, dim=pruning_dim)
-            else:
-                pruning_fn = self._create_pruning_fn(pruning_fn)
+                pruning_kwargs["n"] = pruning_norm
+                pruning_kwargs["dim"] = pruning_dim
+            pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs)
         else:
-            bases = getattr(pruning_fn, "__bases__", None)
-            if bases is None or bases[0] != pytorch_prune.BasePruningMethod:
+            if not isinstance(pruning_fn, pytorch_prune.BasePruningMethod):
                 raise MisconfigurationException(
-                    f'pruning_fn is expected to be the str in {_PYTORCH_PRUNING_FUNCTIONS.keys()} '
-                    f'or a `PyTorch BasePruningMethod`. Found: {pruning_fn}'
+                    f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}"
+                    f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}"
                 )
             if not use_global_unstructured:
+                # TODO: currently not supported
                 raise MisconfigurationException(
-                    '`PyTorch BasePruningMethod` is currently support only for `use_global_unstructured=True`. ')
+                    "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`."
+                )
 
         if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured":
             raise MisconfigurationException(
-                'Only "unstructured" PRUNING_TYPE supported for '
-                f"the `pruning_method`. Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
+                'Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'
+                f" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. "
             )
 
         self.pruning_fn = pruning_fn
-
         self.make_pruning_permanent = make_pruning_permanent
 
         if not isinstance(amount, (int, float, Callable)):
@@ -194,9 +189,9 @@ def __init__(
 
         self.amount = amount
 
-    def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]]):
+    def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None):
         """
-        This function can be overriden to control which module to prune.
+        This function can be overridden to control which module to prune.
         """
         return parameters_to_prune
 
@@ -212,24 +207,18 @@ def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs):
             pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn]
             self._global_kwargs = kwargs
             return pruning_fn
-        else:
-            return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs)
+        return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], *args, **kwargs)
 
     @staticmethod
     def _wrap_pruning_fn(pruning_fn, *args, **kwargs):
-        return partial(pruning_fn, **kwargs)
+        return partial(pruning_fn, *args, **kwargs)
 
     def _make_pruning_permanent(self):
         for module, param_name in self._parameters_to_prune:
             pytorch_prune.remove(module, param_name)
 
     def _resolve_amount(self, current_epoch: int) -> float:
-        if isinstance(self.amount, Callable):
-            amount_fn = self.amount
-            amount = amount_fn(current_epoch)
-        else:
-            amount = self.amount
-        return amount
+        return self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         """
@@ -269,16 +258,14 @@ def _resolve_global_kwargs(self, amount: float):
 
     def _apply_global_pruning(self, amount: float):
         pytorch_prune.global_unstructured(
-            self._parameters_to_prune,
-            pruning_method=self.pruning_fn,
-            **self._resolve_global_kwargs(amount)
+            self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount)
         )
 
-    def apply_pruning(self, trainer: 'pl.Trainer', pl_module: LightningModule):
-        amount = self._resolve_amount(trainer.current_epoch)
+    def apply_pruning(self, current_epoch: int):
+        amount = self._resolve_amount(current_epoch)
 
         # the user could control the pruning frequency with amount_fn
-        if amount == 0 or amount is None:
+        if not amount:
             return
 
         if self._use_global_unstructured:
@@ -291,96 +278,62 @@ def apply_pruning(self, trainer: 'pl.Trainer', pl_module: LightningModule):
 
     def on_before_accelerator_backend_setup(self, trainer, pl_module):
         parameters_to_prune = self.sanitize_parameters_to_prune(
-            pl_module, self._parameters_to_prune, parameters=self._parameter_names)
+            pl_module, self._parameters_to_prune, parameters=self._parameter_names
+        )
 
         self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)
 
         if self._use_lottery_ticket_hypothesis:
-            # make a copy of copy of orginal weights.
+            # make a copy of copy of original weights.
self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune] def on_epoch_end(self, trainer, pl_module): - self.apply_pruning(trainer, pl_module) + self.apply_pruning(trainer.current_epoch) if self.make_pruning_permanent: self._make_pruning_permanent() - @staticmethod - def _sanitize_parameters_to_prune(p): - """ - Check the provide element is a pair with: - * nn.Module - * str - - Example:: - - parameters_to_prune = [ - (model.mlp_1, "weight"), - (model.mlp_2, "weight") - ] - """ - return len(p) == 2 and isinstance(p[0], nn.Module) and isinstance(p[1], str) - @staticmethod def sanitize_parameters_to_prune( pl_module: LightningModule, - parameters_to_prune: Optional[List[Tuple[nn.Module, str]]], - parameters: List[str] = ["weight"] - ) -> List: + parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None, + parameters: Optional[List[str]] = None, + ) -> List[Tuple[nn.Module, str]]: """ - This function is responsible to check provided `parameters_to_prune` and `parameters`. + This function is responsible to check provided ``parameters_to_prune` and `parameters`. If parameters_to_prune is None, parameters_to_prune will be generated from all parameters of the model. """ + parameters = parameters or ModelPruning.PARAMETER_NAMES - is_parameters_to_prune_none = parameters_to_prune is None current_modules = [ - m for m in pl_module.modules() - if not isinstance(m, (LightningModule, ModuleDict, ModuleList)) + m for m in pl_module.modules() if not isinstance(m, (LightningModule, ModuleDict, ModuleList)) ] - if is_parameters_to_prune_none: - parameters_to_prune = [] - for p in parameters: - for m in current_modules: - param = getattr(m, p, None) - if param is not None: - parameters_to_prune.append((m, p)) - - if isinstance(parameters_to_prune, (tuple, list)) \ - and len(parameters_to_prune) > 0 and not is_parameters_to_prune_none: - - if all( - isinstance(p, (list, tuple)) and ModelPruning._sanitize_parameters_to_prune(p) - for p in parameters_to_prune - ): - - missing_modules = [] - missing_parameters = [] - - for module, param_name in parameters_to_prune: - if module not in current_modules: - missing_modules.append(module) - continue - - parameter = getattr(module, param_name) - - if parameter is None: - missing_parameters.append(parameter) - - if len(missing_modules) > 0 or len(missing_parameters) > 0: - raise MisconfigurationException( - "Ths provided parameters_to_tune doesn't exist in the model." - f" Found mismatching modules: {missing_modules} and missing_parameters: {missing_parameters}" - ) - - else: + if parameters_to_prune is None: + parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)] + elif ( + isinstance(parameters_to_prune, (list, tuple)) + and len(parameters_to_prune) > 0 + and all(len(p) == 2 for p in parameters_to_prune) + and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune) + ): + missing_modules, missing_parameters = [], [] + for module, param_name in parameters_to_prune: + if module not in current_modules: + missing_modules.append(module) + continue + if not hasattr(module, param_name): + missing_parameters.append(param_name) + + if len(missing_modules) > 0 or len(missing_parameters) > 0: raise MisconfigurationException( - "The provided parameters_to_prune should either be list of tuple " - "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None") + "Some provided `parameters_to_tune` don't exist in the model." 
+ f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}" + ) else: - if not isinstance(parameters_to_prune, (list, tuple)): - raise MisconfigurationException( - "The provided parameters_to_prune should either be list of tuple " - "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None") + raise MisconfigurationException( + "The provided `parameters_to_prune` should either be list of tuple " + "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None" + ) return parameters_to_prune diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 46bf71f517fae..c0706ce307ad0 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -13,12 +13,13 @@ # limitations under the License. import os import platform +from collections import OrderedDict -import numpy as np import pytest import torch import torch.nn.utils.prune as pytorch_prune from torch import nn +from torch.nn import Sequential from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelPruning @@ -26,199 +27,135 @@ from tests.base import BoringModel -class PruningModel(BoringModel): +class TestModel(BoringModel): + validation_step = None + test_step = None def __init__(self): super().__init__() - self.layer = nn.ModuleDict() + self.layer = Sequential(OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32)), + ("mlp_3", nn.Linear(32, 2)), + ])) - self.layer["mlp_1"] = nn.Linear(32, 32) - self.layer["mlp_2"] = nn.Linear(32, 32) - self.layer["mlp_3"] = nn.Linear(32, 32) - self.layer["mlp_4"] = nn.Linear(32, 32) - self.layer["mlp_5"] = nn.Linear(32, 2) - def forward(self, x): - m = self.layer - x = m["mlp_1"](x) - x = m["mlp_2"](x) - x = m["mlp_3"](x) - x = m["mlp_4"](x) - return m["mlp_5"](x) +class TestPruningMethod(pytorch_prune.BasePruningMethod): + PRUNING_TYPE = "unstructured" - def training_step(self, batch, batch_idx): - output = self.forward(batch) - loss = self.loss(batch, output) - return {"loss": loss} + def compute_mask(self, _, default_mask): + mask = default_mask.clone() + # Prune every other entry in a tensor + mask.view(-1)[::2] = 0 + return mask + + @classmethod + def apply(cls, module, name, amount): + return super(TestPruningMethod, cls).apply(module, name, amount=amount) def train_with_pruning_callback( tmpdir, - parameters_to_prune, - use_global_unstructured, + parameters_to_prune=False, + use_global_unstructured=False, + pruning_fn="l1_unstructured", accelerator=None, gpus=None, - num_processes=None, - use_custom_pruning_fn=False, + num_processes=1, ): - # Skipped as currently not supported. - # Todo: add support for custom pruning_fn function. - if not use_global_unstructured and use_custom_pruning_fn: - return + model = TestModel() - model = PruningModel() - model.validation_step = None - model.test_step = None + # Weights are random. 
None is 0 + assert torch.all(model.layer.mlp_2.weight != 0) + pruning_kwargs = {"pruning_fn": pruning_fn, "amount": 0.3, "use_global_unstructured": use_global_unstructured} if parameters_to_prune: - parameters_to_prune = [ - (model.layer["mlp_1"], "weight"), - (model.layer["mlp_2"], "weight") - ] - - else: - parameters_to_prune = None - - assert torch.sum(model.layer["mlp_2"].weight == 0) == 0 - - class TestPruningMethod(pytorch_prune.BasePruningMethod): - """Prune every other entry in a tensor - """ - PRUNING_TYPE = 'unstructured' - - def compute_mask(self, t, default_mask): - mask = default_mask.clone() - mask.view(-1)[::2] = 0 - return mask - - @classmethod - def apply(cls, module, name, amount): - r"""Adds the forward pre-hook that enables pruning on the fly and - the reparametrization of a tensor in terms of the original tensor - and the pruning mask. - - Args: - module (nn.Module): module containing the tensor to prune - name (str): parameter name within ``module`` on which pruning - will act. - amount (int or float): quantity of parameters to prune. - If ``float``, should be between 0.0 and 1.0 and represent the - fraction of parameters to prune. If ``int``, it represents the - absolute number of parameters to prune. - """ - return super(TestPruningMethod, cls).apply( - module, name, amount=amount - ) - - custom_pruning_fn = TestPruningMethod - - pruning_funcs_structured = [ - "ln_structured", - "random_structured", - ] - - pruning_funcs_unstructured = [ - "l1_unstructured", - "random_unstructured", - ] - - if use_global_unstructured: - pruning_list = pruning_funcs_unstructured - else: - pruning_list = pruning_funcs_unstructured + pruning_funcs_structured - - rand_idx = np.random.randint(len(pruning_list)) - pruning_fn = pruning_list[rand_idx] - - model_pruning_args = { - "pruning_fn": custom_pruning_fn if use_custom_pruning_fn else pruning_fn , - "parameters_to_prune": parameters_to_prune, - "amount": 0.3, - "use_global_unstructured": use_global_unstructured, - } - - if "unstructured" not in pruning_fn: - model_pruning_args["pruning_dim"] = 0 - + pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")] + if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"): + pruning_kwargs["pruning_dim"] = 0 if pruning_fn == "ln_structured": - model_pruning_args["pruning_norm"] = 1 + pruning_kwargs["pruning_norm"] = 1 + + # Misconfiguration checks + if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured") and use_global_unstructured: + with pytest.raises(MisconfigurationException, match="is supported with `use_global_unstructured=True`"): + ModelPruning(**pruning_kwargs) + return + if not isinstance(pruning_fn, str) and not use_global_unstructured: + with pytest.raises(MisconfigurationException, match="currently only supported with"): + ModelPruning(**pruning_kwargs) + return + + pruning = ModelPruning(**pruning_kwargs) trainer = Trainer( default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, limit_train_batches=10, limit_val_batches=2, max_epochs=10, accelerator=accelerator, gpus=gpus, num_processes=num_processes, - callbacks=ModelPruning(**model_pruning_args) + callbacks=pruning, ) trainer.fit(model) - _ = trainer.test(model) - - if accelerator is None: - assert torch.sum(model.layer["mlp_2"].weight == 0) > 0 + trainer.test(model) + if accelerator not in ("ddp_cpu", "ddp_spawn"): + # Check some have been pruned + assert 
torch.any(model.layer.mlp_2.weight == 0) -def test_with_pruning_callback_misconfiguration(tmpdir): - model_pruning_args = { - "parameter_names": ["chocolat"], - } - with pytest.raises(MisconfigurationException, match='provided parameter_names'): - _ = ModelPruning(**model_pruning_args) - - model_pruning_args = { - "parameter_names": ["weight"], - "pruning_fn": model_pruning_args - } - - with pytest.raises(MisconfigurationException, match='pruning_fn is expected to be the str in'): - _ = ModelPruning(**model_pruning_args) - - model_pruning_args = { - "parameter_names": ["weight"], - "pruning_fn": "random_structured" - } - - with pytest.raises(MisconfigurationException, match='should be provided'): - _ = ModelPruning(**model_pruning_args) - - model_pruning_args = { - "parameter_names": ["weight"], - "pruning_fn": "ln_structured", - "pruning_dim": 0 - } - - with pytest.raises(MisconfigurationException, match='requesting `ln_structured` pruning, the `pruning_norm`'): - _ = ModelPruning(**model_pruning_args) +def test_pruning_misconfiguration(): + with pytest.raises(MisconfigurationException, match="provided parameter_names"): + ModelPruning(parameter_names=["chocolate"]) + with pytest.raises(MisconfigurationException, match=r"expected to be a str in \["): + ModelPruning(parameter_names=["weight"], pruning_fn={}) # noqa + with pytest.raises(MisconfigurationException, match="should be provided"): + ModelPruning(parameter_names=["weight"], pruning_fn="random_structured") + with pytest.raises(MisconfigurationException, match="requesting `ln_structured` pruning, the `pruning_norm`"): + ModelPruning(parameter_names=["weight"], pruning_fn="ln_structured", pruning_dim=0) @pytest.mark.parametrize("parameters_to_prune", [False, True]) @pytest.mark.parametrize("use_global_unstructured", [False, True]) -@pytest.mark.parametrize("use_custom_pruning_fn", [False, True]) -def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, use_custom_pruning_fn): +@pytest.mark.parametrize( + "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod()] +) +def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn): train_with_pruning_callback( - tmpdir, parameters_to_prune, use_global_unstructured, - accelerator=None, gpus=None, num_processes=1, use_custom_pruning_fn=use_custom_pruning_fn) + tmpdir, + parameters_to_prune=parameters_to_prune, + use_global_unstructured=use_global_unstructured, + pruning_fn=pruning_fn, + ) @pytest.mark.parametrize("parameters_to_prune", [False, True]) @pytest.mark.parametrize("use_global_unstructured", [False, True]) -@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', - reason="test should be run outside of pytest") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") == "1", reason="test should be run outside of pytest" +) def test_pruning_callback_ddp(tmpdir, use_global_unstructured, parameters_to_prune): train_with_pruning_callback( - tmpdir, parameters_to_prune, use_global_unstructured, - accelerator="ddp", gpus=2, num_processes=0) + tmpdir, + parameters_to_prune=parameters_to_prune, + use_global_unstructured=use_global_unstructured, + accelerator="ddp", + gpus=2, + ) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_spawn(tmpdir): - 
train_with_pruning_callback(tmpdir, False, True, accelerator="ddp_spawn", gpus=2, num_processes=None) + train_with_pruning_callback(tmpdir, use_global_unstructured=True, accelerator="ddp_spawn", gpus=2) @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_cpu(tmpdir): - train_with_pruning_callback(tmpdir, True, False, accelerator="ddp_cpu", gpus=None, num_processes=2) + train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2) From c1af925c4aeeffd833fd37936a506ffb5c614857 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 04:50:14 +0100 Subject: [PATCH 03/27] Fix misconfig test --- tests/callbacks/test_pruning.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index c0706ce307ad0..1cd74ae140f2f 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -33,11 +33,15 @@ class TestModel(BoringModel): def __init__(self): super().__init__() - self.layer = Sequential(OrderedDict([ - ("mlp_1", nn.Linear(32, 32)), - ("mlp_2", nn.Linear(32, 32)), - ("mlp_3", nn.Linear(32, 2)), - ])) + self.layer = Sequential( + OrderedDict( + [ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32)), + ("mlp_3", nn.Linear(32, 2)), + ] + ) + ) class TestPruningMethod(pytorch_prune.BasePruningMethod): @@ -111,14 +115,14 @@ def train_with_pruning_callback( def test_pruning_misconfiguration(): - with pytest.raises(MisconfigurationException, match="provided parameter_names"): + with pytest.raises(MisconfigurationException, match=r"chocolate isn't in \('weight', 'bias'\)"): ModelPruning(parameter_names=["chocolate"]) with pytest.raises(MisconfigurationException, match=r"expected to be a str in \["): - ModelPruning(parameter_names=["weight"], pruning_fn={}) # noqa + ModelPruning(pruning_fn={}) # noqa with pytest.raises(MisconfigurationException, match="should be provided"): - ModelPruning(parameter_names=["weight"], pruning_fn="random_structured") + ModelPruning(pruning_fn="random_structured") with pytest.raises(MisconfigurationException, match="requesting `ln_structured` pruning, the `pruning_norm`"): - ModelPruning(parameter_names=["weight"], pruning_fn="ln_structured", pruning_dim=0) + ModelPruning(pruning_fn="ln_structured", pruning_dim=0) @pytest.mark.parametrize("parameters_to_prune", [False, True]) From 8e79ac97cd5f68ad66a46b2d0a1abe83f9cc061d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 13:50:33 +0100 Subject: [PATCH 04/27] Fix tests --- pytorch_lightning/callbacks/pruning.py | 29 +++++++++++++------------- tests/callbacks/test_pruning.py | 22 +++++++++---------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index f95fab3b53071..32ebbdd5184f2 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -20,7 +20,7 @@ import inspect from copy import deepcopy from functools import partial -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Type, Union import torch.nn.utils.prune as pytorch_prune from torch import nn @@ -161,7 +161,7 @@ def __init__( pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) else: - if not isinstance(pruning_fn, pytorch_prune.BasePruningMethod): + if not 
self.is_pruning_method(pruning_fn): raise MisconfigurationException( f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}" f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}" @@ -195,7 +195,7 @@ def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn """ return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs): + def _create_pruning_fn(self, pruning_fn: str, **kwargs): """ This function takes `pruning_fn`, a function name. @@ -207,11 +207,11 @@ def _create_pruning_fn(self, pruning_fn: str, *args, **kwargs): pruning_fn = _PYTORCH_PRUNING_METHOD[pruning_fn] self._global_kwargs = kwargs return pruning_fn - return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], *args, **kwargs) + return ModelPruning._wrap_pruning_fn(_PYTORCH_PRUNING_FUNCTIONS[pruning_fn], **kwargs) @staticmethod - def _wrap_pruning_fn(pruning_fn, *args, **kwargs): - return partial(pruning_fn, *args, **kwargs) + def _wrap_pruning_fn(pruning_fn, **kwargs): + return partial(pruning_fn, **kwargs) def _make_pruning_permanent(self): for module, param_name in self._parameters_to_prune: @@ -246,15 +246,10 @@ def _apply_local_pruning(self, amount: float): self.pruning_fn(module, name=param, amount=amount) def _resolve_global_kwargs(self, amount: float): - kwargs = {} self._global_kwargs["amount"] = amount - params = inspect.signature(self.pruning_fn).parameters - for p_name in params: - if p_name != "self": - param = self._global_kwargs.get(p_name) - if param is not None: - kwargs[p_name] = param - return kwargs + params = set(inspect.signature(self.pruning_fn).parameters) + params.discard("self") + return {k: v for k, v in self._global_kwargs.items() if k in params} def _apply_global_pruning(self, amount: float): pytorch_prune.global_unstructured( @@ -337,3 +332,9 @@ def sanitize_parameters_to_prune( ) return parameters_to_prune + + @staticmethod + def is_pruning_method(method) -> bool: + if not inspect.isclass(method): + return False + return issubclass(method, pytorch_prune.BasePruningMethod) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 1cd74ae140f2f..0bf5de6070171 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -33,15 +33,11 @@ class TestModel(BoringModel): def __init__(self): super().__init__() - self.layer = Sequential( - OrderedDict( - [ - ("mlp_1", nn.Linear(32, 32)), - ("mlp_2", nn.Linear(32, 32)), - ("mlp_3", nn.Linear(32, 2)), - ] - ) - ) + self.layer = Sequential(OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32)), + ("mlp_3", nn.Linear(32, 2)), + ])) class TestPruningMethod(pytorch_prune.BasePruningMethod): @@ -75,6 +71,8 @@ def train_with_pruning_callback( pruning_kwargs = {"pruning_fn": pruning_fn, "amount": 0.3, "use_global_unstructured": use_global_unstructured} if parameters_to_prune: pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")] + else: + pruning_kwargs["parameter_names"] = ["weight"] if isinstance(pruning_fn, str) and pruning_fn.endswith("_structured"): pruning_kwargs["pruning_dim"] = 0 if pruning_fn == "ln_structured": @@ -85,7 +83,7 @@ def train_with_pruning_callback( with pytest.raises(MisconfigurationException, match="is supported with `use_global_unstructured=True`"): ModelPruning(**pruning_kwargs) return - if not isinstance(pruning_fn, str) and not use_global_unstructured: + if ModelPruning.is_pruning_method(pruning_fn) and not 
use_global_unstructured: with pytest.raises(MisconfigurationException, match="currently only supported with"): ModelPruning(**pruning_kwargs) return @@ -109,7 +107,7 @@ def train_with_pruning_callback( trainer.fit(model) trainer.test(model) - if accelerator not in ("ddp_cpu", "ddp_spawn"): + if not accelerator: # Check some have been pruned assert torch.any(model.layer.mlp_2.weight == 0) @@ -128,7 +126,7 @@ def test_pruning_misconfiguration(): @pytest.mark.parametrize("parameters_to_prune", [False, True]) @pytest.mark.parametrize("use_global_unstructured", [False, True]) @pytest.mark.parametrize( - "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod()] + "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod] ) def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn): train_with_pruning_callback( From cdd064ea7616454003924d193038bd2f491c2ba8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 13:57:09 +0100 Subject: [PATCH 05/27] Improve error message --- pytorch_lightning/callbacks/pruning.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 32ebbdd5184f2..26c6805f53297 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -159,19 +159,18 @@ def __init__( pruning_kwargs["n"] = pruning_norm pruning_kwargs["dim"] = pruning_dim pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) - - else: - if not self.is_pruning_method(pruning_fn): - raise MisconfigurationException( - f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}" - f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}" - ) - + elif self.is_pruning_method(pruning_fn): if not use_global_unstructured: # TODO: currently not supported raise MisconfigurationException( "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`." ) + else: + raise MisconfigurationException( + f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}" + f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}." + " HINT: if passing a `BasePruningMethod`, pass the the class, not an instance" + ) if use_global_unstructured and pruning_fn.PRUNING_TYPE != "unstructured": raise MisconfigurationException( From 0e32388504704488763a71309b0cf05fcb395277 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 13:59:35 +0100 Subject: [PATCH 06/27] Reduce whitespace --- pytorch_lightning/callbacks/pruning.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 26c6805f53297..25993644d53df 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -76,16 +76,14 @@ def __init__( (model.mlp_2, "weight") ] - trainer = Trainer( - callbacks=[ - ModelPruning( - pruning_fn='l1_unstructured', - parameters_to_prune=parameters_to_prune, - amount=0.01, - use_global_unstructured=True, - ) - ] - ) + trainer = Trainer(callbacks=[ + ModelPruning( + pruning_fn='l1_unstructured', + parameters_to_prune=parameters_to_prune, + amount=0.01, + use_global_unstructured=True, + ) + ]) When `parameters_to_prune` is None, `parameters_to_prune` will contains all parameters from the model. 
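        For instance, with the default ``parameter_names`` a single ``nn.Linear`` contributes two entries, roughly::

            [(module, "weight"), (module, "bias")]

        (a sketch of what ``sanitize_parameters_to_prune`` collects; the exact list depends on the modules present in the model)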
The user can override `filter_parameters_to_prune` to filter any nn.Module to be pruned. From 7fa8acd6f22e2e0354577fbc79c3811668622e63 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 14:41:25 +0100 Subject: [PATCH 07/27] WIP --- pytorch_lightning/callbacks/pruning.py | 102 +++++++++++++------------ tests/callbacks/test_pruning.py | 12 ++- 2 files changed, 63 insertions(+), 51 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 25993644d53df..ec37f97cb9622 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -20,7 +20,7 @@ import inspect from copy import deepcopy from functools import partial -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch.nn.utils.prune as pytorch_prune from torch import nn @@ -44,6 +44,8 @@ "random_unstructured": pytorch_prune.RandomUnstructured, } +_PARAM_LIST = List[Tuple[nn.Module, str]] + class ModelPruning(Callback): PARAMETER_NAMES = ("weight", "bias") @@ -51,12 +53,12 @@ class ModelPruning(Callback): def __init__( self, pruning_fn: Optional[Union[Callable, str]] = None, - parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None, + parameters_to_prune: Optional[_PARAM_LIST] = None, parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, - amount: Union[int, float, Callable] = 0.5, + amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, make_pruning_permanent: bool = True, - use_lottery_ticket_hypothesis: bool = True, + use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, pruning_dim: Optional[int] = None, pruning_norm: Optional[int] = None, ) -> None: @@ -85,41 +87,38 @@ def __init__( ) ]) - When `parameters_to_prune` is None, `parameters_to_prune` will contains all parameters from the model. - The user can override `filter_parameters_to_prune` to filter any nn.Module to be pruned. + When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model. + The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned. Args: - pruning_fn: function from torch.nn.utils.prune module - or your based own subclasses from PyTorch ``BasePruningMethod``. - Can be string e.g. `"l1_unstructured"`. - See pytorch docs for more details. + pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass. + Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details. - parameters_to_prune: list of strings or list of tuple with - nn.Module and its associated string name parameters. + parameters_to_prune: List of strings or list of tuples ``(nn.Module, "parameter_name_string")``. - parameter_names: List of parameter names to be used from nn.Module. - Can either be `weight` or `bias`. + parameter_names: List of parameter names to be pruned from the nn.Module. + Can either be ``"weight"`` or ``"bias"``. use_global_unstructured: Whether to apply pruning globally on the model. - If parameters_to_prune is provided, global_unstructured will be restricted on them. + If ``parameters_to_prune`` is provided, global unstructured will be restricted on them. + + amount: Quantity of parameters to prune: - amount: quantity of parameters to prune: + - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune. + - ``int``. 
Represents the absolute number of parameters to prune. + - ``Callable``. For dynamic values. Will be called every epoch. Should return a value. - - float, between 0.0 and 1.0. represents the fraction of parameters to prune. - - int, represents the absolute number of parameters to prune. - - Callable, for dynamic values. will be called every epoch. + make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks on fit end. - make_pruning_permanent: if True then all reparametrization pre-hooks and tensors - with mask will be removed on fit end. + use_lottery_ticket_hypothesis: See "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf): - use_lottery_ticket_hypothesis: Whether to use algorithm describes in - "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf) + - ``bool``. Whether to apply it or not. + - ``Callable``. For dynamic values. Will be called every epoch. Should return a bool - pruning_dim: if you are using structured pruning method you need - to specify dimension. + pruning_dim: If you are using a structured pruning method you need to specify the dimension. - pruning_norm: if you are using ln_structured you need to specify norm. + pruning_norm: If you are using ``ln_structured`` you need to specify the norm. """ @@ -186,13 +185,13 @@ def __init__( self.amount = amount - def filter_parameters_to_prune(self, parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None): + def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: """ This function can be overridden to control which module to prune. """ return parameters_to_prune - def _create_pruning_fn(self, pruning_fn: str, **kwargs): + def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytorch_prune.BasePruningMethod]: """ This function takes `pruning_fn`, a function name. @@ -214,20 +213,7 @@ def _make_pruning_permanent(self): for module, param_name in self._parameters_to_prune: pytorch_prune.remove(module, param_name) - def _resolve_amount(self, current_epoch: int) -> float: - return self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount - def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): - """ - "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf) algorithm: - - 1. Randomly initialize a neural network f(x;θ0)(where θ0 ∼Dθ). - 2. Train the network for j iterations, arriving at parameters θj . - 3. Prune p% of the parameters in θj , creating a mask m. - 4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m⊙θ0). - - This function is responsible of step 4. - """ trained = getattr(module, tensor_name) orig = getattr(orig_module, tensor_name) if trained is None or orig is None: @@ -235,8 +221,23 @@ def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, t trained.data = orig.data.to(trained.device) def apply_lottery_ticket_hypothesis(self): - for (mod, tensor_name), (initial_mod, _) in zip(self._parameters_to_prune, self._initial_parameters_to_prune): - self._restore_original_weights(mod, initial_mod, tensor_name) + """ + Lottery ticket hypothesis algorithm (see page 2 of the paper): + + 1. Randomly initialize a neural network f(x; θ_0) (where θ_0 ∼ D_θ). + 2. Train the network for j iterations, arriving at parameters θ_j . + 3. Prune p% of the parameters in θ_j, creating a mask m. + 4. 
Reset the remaining parameters to their values in θ_0, creating the winning ticket f(x; m⊙θ_0). + + This function implements the step 4. + """ + for (new, new_name), (old, old_name) in zip(self._parameters_to_prune, self._initial_parameters_to_prune): + trained = getattr(new, new_name) + orig = getattr(old, new_name) + assert new_name == old_name + if trained is None or orig is None: + return + trained.data = orig.data.to(trained.device) def _apply_local_pruning(self, amount: float): for module, param in self._parameters_to_prune: @@ -254,8 +255,7 @@ def _apply_global_pruning(self, amount: float): ) def apply_pruning(self, current_epoch: int): - amount = self._resolve_amount(current_epoch) - + amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount # the user could control the pruning frequency with amount_fn if not amount: return @@ -265,7 +265,11 @@ def apply_pruning(self, current_epoch: int): else: self._apply_local_pruning(amount) - if self._use_lottery_ticket_hypothesis: + if ( + self._use_lottery_ticket_hypothesis(current_epoch) + if isinstance(self._use_lottery_ticket_hypothesis, Callable) + else self._use_lottery_ticket_hypothesis + ): self.apply_lottery_ticket_hypothesis() def on_before_accelerator_backend_setup(self, trainer, pl_module): @@ -288,9 +292,9 @@ def on_epoch_end(self, trainer, pl_module): @staticmethod def sanitize_parameters_to_prune( pl_module: LightningModule, - parameters_to_prune: Optional[List[Tuple[nn.Module, str]]] = None, + parameters_to_prune: Optional[_PARAM_LIST] = None, parameters: Optional[List[str]] = None, - ) -> List[Tuple[nn.Module, str]]: + ) -> _PARAM_LIST: """ This function is responsible to check provided ``parameters_to_prune` and `parameters`. If parameters_to_prune is None, parameters_to_prune will be generated from all parameters of the model. @@ -331,7 +335,7 @@ def sanitize_parameters_to_prune( return parameters_to_prune @staticmethod - def is_pruning_method(method) -> bool: + def is_pruning_method(method: Any) -> bool: if not inspect.isclass(method): return False return issubclass(method, pytorch_prune.BasePruningMethod) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 0bf5de6070171..384e5140ae46c 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -59,6 +59,7 @@ def train_with_pruning_callback( parameters_to_prune=False, use_global_unstructured=False, pruning_fn="l1_unstructured", + use_lottery_ticket_hypothesis=False, accelerator=None, gpus=None, num_processes=1, @@ -68,7 +69,7 @@ def train_with_pruning_callback( # Weights are random. 
None is 0 assert torch.all(model.layer.mlp_2.weight != 0) - pruning_kwargs = {"pruning_fn": pruning_fn, "amount": 0.3, "use_global_unstructured": use_global_unstructured} + pruning_kwargs = {"pruning_fn": pruning_fn, "amount": 0.3, "use_global_unstructured": use_global_unstructured, "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis} if parameters_to_prune: pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")] else: @@ -128,12 +129,14 @@ def test_pruning_misconfiguration(): @pytest.mark.parametrize( "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod] ) -def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn): +@pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True]) +def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn, use_lottery_ticket_hypothesis): train_with_pruning_callback( tmpdir, parameters_to_prune=parameters_to_prune, use_global_unstructured=use_global_unstructured, pruning_fn=pruning_fn, + use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesis, ) @@ -161,3 +164,8 @@ def test_pruning_callback_ddp_spawn(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_cpu(tmpdir): train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2) + + +# TODO: lottery ticket tests +# TODO: iterative pruning tests +# TODO: saving tests From 9820e7ab468cb3183588545088075c7884883ab0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 15:12:08 +0100 Subject: [PATCH 08/27] TODOs --- tests/callbacks/test_pruning.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 384e5140ae46c..8e0c94c7f05b9 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -169,3 +169,6 @@ def test_pruning_callback_ddp_cpu(tmpdir): # TODO: lottery ticket tests # TODO: iterative pruning tests # TODO: saving tests +# TODO: sparsity history and tracking +# TODO: allow resampling +# TODO: second chance From c9bb9972e35b2b1ef1608627e445cf31f674419a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 5 Feb 2021 15:32:55 +0100 Subject: [PATCH 09/27] _MODULE_CONTAINERS --- pytorch_lightning/callbacks/pruning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index ec37f97cb9622..a6e545aba28aa 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -24,7 +24,6 @@ import torch.nn.utils.prune as pytorch_prune from torch import nn -from torch.nn.modules.container import ModuleDict, ModuleList from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule @@ -45,6 +44,7 @@ } _PARAM_LIST = List[Tuple[nn.Module, str]] +_MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict) class ModelPruning(Callback): @@ -158,7 +158,6 @@ def __init__( pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) elif self.is_pruning_method(pruning_fn): if not use_global_unstructured: - # TODO: currently not supported raise MisconfigurationException( "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`." 
) @@ -302,7 +301,7 @@ def sanitize_parameters_to_prune( parameters = parameters or ModelPruning.PARAMETER_NAMES current_modules = [ - m for m in pl_module.modules() if not isinstance(m, (LightningModule, ModuleDict, ModuleList)) + m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS) ] if parameters_to_prune is None: From 1b11bd2a4fe3d8d28ecd0cde973fcaf78482d17a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 02:58:20 +0100 Subject: [PATCH 10/27] Add LTH test --- pytorch_lightning/callbacks/pruning.py | 29 +++++--------- tests/callbacks/test_pruning.py | 54 +++++++++++++++++++++----- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 0e4cd59e43c67..4047eba5fd6f5 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -28,7 +28,6 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities.exceptions import MisconfigurationException - _PYTORCH_PRUNING_FUNCTIONS = { "ln_structured": pytorch_prune.ln_structured, "l1_unstructured": pytorch_prune.l1_unstructured, @@ -52,7 +51,7 @@ class ModelPruning(Callback): def __init__( self, - pruning_fn: Optional[Union[Callable, str]] = None, + pruning_fn: Union[Callable, str], parameters_to_prune: Optional[_PARAM_LIST] = None, parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, @@ -63,11 +62,8 @@ def __init__( pruning_norm: Optional[int] = None, ) -> None: """ - - Pruning Callback relying on PyTorch prune utils. - - This callback is responsible to prune networks parameters - during your training. + Model pruning Callback, using PyTorch's prune utilities. + This callback is responsible of pruning networks parameters during training. Find here the PyTorch (Pruning Tutorial)[https://pytorch.org/tutorials/intermediate/pruning_tutorial.html] @@ -114,7 +110,7 @@ def __init__( use_lottery_ticket_hypothesis: See "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf): - ``bool``. Whether to apply it or not. - - ``Callable``. For dynamic values. Will be called every epoch. Should return a bool + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. pruning_dim: If you are using a structured pruning method you need to specify the dimension. @@ -264,11 +260,9 @@ def apply_pruning(self, current_epoch: int): else: self._apply_local_pruning(amount) - if ( - self._use_lottery_ticket_hypothesis(current_epoch) - if isinstance(self._use_lottery_ticket_hypothesis, Callable) - else self._use_lottery_ticket_hypothesis - ): + if self._use_lottery_ticket_hypothesis(current_epoch) if isinstance( + self._use_lottery_ticket_hypothesis, Callable + ) else self._use_lottery_ticket_hypothesis: self.apply_lottery_ticket_hypothesis() def on_before_accelerator_backend_setup(self, trainer, pl_module): @@ -282,7 +276,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): # make a copy of copy of original weights. 
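            # With this patch `use_lottery_ticket_hypothesis` may also be a callable
            # receiving the current epoch, e.g. (a sketch mirroring the new test below):
            #
            #     ModelPruning("l1_unstructured", use_lottery_ticket_hypothesis=lambda epoch: bool(epoch % 2))
            #
            # so the theta_0 reset can be scheduled instead of running every epoch.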
self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune] - def on_epoch_end(self, trainer, pl_module): + def on_train_epoch_end(self, trainer, pl_module, *args): self.apply_pruning(trainer.current_epoch) if self.make_pruning_permanent: @@ -300,15 +294,12 @@ def sanitize_parameters_to_prune( """ parameters = parameters or ModelPruning.PARAMETER_NAMES - current_modules = [ - m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS) - ] + current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] if parameters_to_prune is None: parameters_to_prune = [(m, p) for p in parameters for m in current_modules if hasattr(m, p)] elif ( - isinstance(parameters_to_prune, (list, tuple)) - and len(parameters_to_prune) > 0 + isinstance(parameters_to_prune, (list, tuple)) and len(parameters_to_prune) > 0 and all(len(p) == 2 for p in parameters_to_prune) and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune) ): diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 1e6d6a6cfb0da..326aec6a8555b 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -22,7 +22,7 @@ from torch.nn import Sequential from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelPruning +from pytorch_lightning.callbacks import Callback, ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -33,11 +33,13 @@ class TestModel(BoringModel): def __init__(self): super().__init__() - self.layer = Sequential(OrderedDict([ - ("mlp_1", nn.Linear(32, 32)), - ("mlp_2", nn.Linear(32, 32)), - ("mlp_3", nn.Linear(32, 2)), - ])) + self.layer = Sequential( + OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 32)), + ("mlp_3", nn.Linear(32, 2)), + ]) + ) class TestPruningMethod(pytorch_prune.BasePruningMethod): @@ -69,7 +71,12 @@ def train_with_pruning_callback( # Weights are random. 
None is 0 assert torch.all(model.layer.mlp_2.weight != 0) - pruning_kwargs = {"pruning_fn": pruning_fn, "amount": 0.3, "use_global_unstructured": use_global_unstructured, "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis} + pruning_kwargs = { + "pruning_fn": pruning_fn, + "amount": 0.3, + "use_global_unstructured": use_global_unstructured, + "use_lottery_ticket_hypothesis": use_lottery_ticket_hypothesis + } if parameters_to_prune: pruning_kwargs["parameters_to_prune"] = [(model.layer.mlp_1, "weight"), (model.layer.mlp_2, "weight")] else: @@ -115,7 +122,7 @@ def train_with_pruning_callback( def test_pruning_misconfiguration(): with pytest.raises(MisconfigurationException, match=r"chocolate isn't in \('weight', 'bias'\)"): - ModelPruning(parameter_names=["chocolate"]) + ModelPruning(pruning_fn="l1_unstructured", parameter_names=["chocolate"]) with pytest.raises(MisconfigurationException, match=r"expected to be a str in \["): ModelPruning(pruning_fn={}) # noqa with pytest.raises(MisconfigurationException, match="should be provided"): @@ -130,7 +137,9 @@ def test_pruning_misconfiguration(): "pruning_fn", ["l1_unstructured", "random_unstructured", "ln_structured", "random_structured", TestPruningMethod] ) @pytest.mark.parametrize("use_lottery_ticket_hypothesis", [False, True]) -def test_pruning_callback(tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn, use_lottery_ticket_hypothesis): +def test_pruning_callback( + tmpdir, use_global_unstructured, parameters_to_prune, pruning_fn, use_lottery_ticket_hypothesis +): train_with_pruning_callback( tmpdir, parameters_to_prune=parameters_to_prune, @@ -164,3 +173,30 @@ def test_pruning_callback_ddp_spawn(tmpdir): @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") def test_pruning_callback_ddp_cpu(tmpdir): train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2) + + +def test_pruning_lth_callable(tmpdir): + model = TestModel() + + class ModelPruningTestCallback(ModelPruning): + lth_calls = 0 + + def apply_lottery_ticket_hypothesis(self): + super().apply_lottery_ticket_hypothesis() + self.lth_calls += 1 + + pruning = ModelPruningTestCallback("l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2)) + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, + limit_train_batches=10, + limit_val_batches=2, + max_epochs=5, + callbacks=pruning, + ) + assert pruning.lth_calls == 0 + trainer.fit(model) + assert pruning.lth_calls == trainer.max_epochs // 2 From f853b55c13b2ddd9ec95c042dad4682c4d5b0dbd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 04:46:48 +0100 Subject: [PATCH 11/27] Allow resampling --- pytorch_lightning/callbacks/pruning.py | 66 +++++++++++++++++--------- tests/callbacks/test_pruning.py | 20 ++++++-- 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 4047eba5fd6f5..542e4d5c4973c 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -21,6 +21,7 @@ from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union +import torch import torch.nn.utils.prune as pytorch_prune from torch import nn @@ -58,6 +59,7 @@ def __init__( amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, make_pruning_permanent: bool = 
True, use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, + resample_parameters: bool = False, pruning_dim: Optional[int] = None, pruning_norm: Optional[int] = None, ) -> None: @@ -112,6 +114,9 @@ def __init__( - ``bool``. Whether to apply it or not. - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will + be resampled, otherwise, the exact original parameters will be used. + pruning_dim: If you are using a structured pruning method you need to specify the dimension. pruning_norm: If you are using ``ln_structured`` you need to specify the norm. @@ -121,14 +126,15 @@ def __init__( self._use_global_unstructured = use_global_unstructured self._parameters_to_prune = parameters_to_prune self._use_lottery_ticket_hypothesis = use_lottery_ticket_hypothesis + self._resample_parameters = resample_parameters self._parameter_names = parameter_names or self.PARAMETER_NAMES self._global_kwargs = {} - self._initial_parameters_to_prune = None + self._original_layers = None - for param_name in self._parameter_names: - if param_name not in self.PARAMETER_NAMES: + for name in self._parameter_names: + if name not in self.PARAMETER_NAMES: raise MisconfigurationException( - f"The provided `parameter_names`: {param_name} isn't in {self.PARAMETER_NAMES}" + f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}" ) if isinstance(pruning_fn, str): @@ -225,18 +231,29 @@ def apply_lottery_ticket_hypothesis(self): 4. Reset the remaining parameters to their values in θ_0, creating the winning ticket f(x; m⊙θ_0). This function implements the step 4. + + The ``resample_parameters`` argument can be used to reset the parameters with a new θ ∼ D_θ """ - for (new, new_name), (old, old_name) in zip(self._parameters_to_prune, self._initial_parameters_to_prune): - trained = getattr(new, new_name) - orig = getattr(old, new_name) - assert new_name == old_name - if trained is None or orig is None: + + def copy_param(new, old, name: str) -> None: + dst = getattr(new, name) + src = getattr(old, name) + if dst is None or src is None or not isinstance(dst, torch.Tensor) or not isinstance(src, torch.Tensor): return - trained.data = orig.data.to(trained.device) + dst.data = src.data.to(dst.device) + + for d in self._original_layers.values(): + copy, names = d["data"], d["names"] + if self._resample_parameters and hasattr(copy, "reset_parameters"): + copy = deepcopy(copy) # keep the original parameters + copy.reset_parameters() + for i, name in names: + new, new_name = self._parameters_to_prune[i] + copy_param(new, copy, name) def _apply_local_pruning(self, amount: float): - for module, param in self._parameters_to_prune: - self.pruning_fn(module, name=param, amount=amount) + for module, name in self._parameters_to_prune: + self.pruning_fn(module, name=name, amount=amount) def _resolve_global_kwargs(self, amount: float): self._global_kwargs["amount"] = amount @@ -267,14 +284,19 @@ def apply_pruning(self, current_epoch: int): def on_before_accelerator_backend_setup(self, trainer, pl_module): parameters_to_prune = self.sanitize_parameters_to_prune( - pl_module, self._parameters_to_prune, parameters=self._parameter_names + pl_module, self._parameters_to_prune, parameter_names=self._parameter_names ) self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune) if self._use_lottery_ticket_hypothesis: - # make a copy of copy of original weights. 
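
The lottery-ticket restore step this hunk implements can be sketched with plain PyTorch pruning utilities. This is a minimal illustration only, not the callback's exact code: the layer is made up, and writing into ``weight_orig`` is just one possible way to reset the surviving weights to θ_0 (or to freshly drawn values when ``resample_parameters=True``).

.. code-block:: python

    from copy import deepcopy

    import torch.nn.utils.prune as prune
    from torch import nn

    layer = nn.Linear(32, 32)
    initial = deepcopy(layer)  # theta_0, captured before any training or pruning

    prune.l1_unstructured(layer, "weight", amount=0.5)  # adds weight_orig and weight_mask

    # step 4 of the paper: reset the surviving weights to their initial values theta_0
    layer.weight_orig.data = initial.weight.data.to(layer.weight_orig.device)

    # with resampling, fresh values theta ~ D_theta are drawn instead of reusing theta_0
    resampled = deepcopy(initial)
    resampled.reset_parameters()
    layer.weight_orig.data = resampled.weight.data.to(layer.weight_orig.device)
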
- self._initial_parameters_to_prune = [(deepcopy(m), n) for m, n in self._parameters_to_prune] + # group modules by id. Each entry has a copy of the initial data + # and a list of the associated parameter names to prune + self._original_layers = {} + for i, (module, name) in enumerate(self._parameters_to_prune): + id_ = id(module) + self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []}) + self._original_layers[id_]["names"].append((i, name)) def on_train_epoch_end(self, trainer, pl_module, *args): self.apply_pruning(trainer.current_epoch) @@ -286,13 +308,13 @@ def on_train_epoch_end(self, trainer, pl_module, *args): def sanitize_parameters_to_prune( pl_module: LightningModule, parameters_to_prune: Optional[_PARAM_LIST] = None, - parameters: Optional[List[str]] = None, + parameter_names: Optional[List[str]] = None, ) -> _PARAM_LIST: """ This function is responsible to check provided ``parameters_to_prune` and `parameters`. If parameters_to_prune is None, parameters_to_prune will be generated from all parameters of the model. """ - parameters = parameters or ModelPruning.PARAMETER_NAMES + parameters = parameter_names or ModelPruning.PARAMETER_NAMES current_modules = [m for m in pl_module.modules() if not isinstance(m, _MODULE_CONTAINERS)] @@ -304,12 +326,12 @@ def sanitize_parameters_to_prune( and all(isinstance(a, nn.Module) and isinstance(b, str) for a, b in parameters_to_prune) ): missing_modules, missing_parameters = [], [] - for module, param_name in parameters_to_prune: + for module, name in parameters_to_prune: if module not in current_modules: missing_modules.append(module) continue - if not hasattr(module, param_name): - missing_parameters.append(param_name) + if not hasattr(module, name): + missing_parameters.append(name) if missing_modules or missing_parameters: raise MisconfigurationException( @@ -318,8 +340,8 @@ def sanitize_parameters_to_prune( ) else: raise MisconfigurationException( - "The provided `parameters_to_prune` should either be list of tuple " - "with 2 elements: (nn.Module in your model, parameter_name_to_prune) or None" + "The provided `parameters_to_prune` should either be list of tuple" + " with 2 elements: (nn.Module, parameter_name_to_prune) or None" ) return parameters_to_prune diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 326aec6a8555b..7a42e5865508a 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -22,7 +22,7 @@ from torch.nn import Sequential from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import Callback, ModelPruning +from pytorch_lightning.callbacks import ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -175,7 +175,8 @@ def test_pruning_callback_ddp_cpu(tmpdir): train_with_pruning_callback(tmpdir, parameters_to_prune=True, accelerator="ddp_cpu", num_processes=2) -def test_pruning_lth_callable(tmpdir): +@pytest.mark.parametrize("resample_parameters", (False, True)) +def test_pruning_lth_callable(tmpdir, resample_parameters): model = TestModel() class ModelPruningTestCallback(ModelPruning): @@ -185,7 +186,18 @@ def apply_lottery_ticket_hypothesis(self): super().apply_lottery_ticket_hypothesis() self.lth_calls += 1 - pruning = ModelPruningTestCallback("l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2)) + for d in self._original_layers.values(): + copy, names = d["data"], d["names"] + for i, name in names: + cur, cur_name = 
self._parameters_to_prune[i] + assert name == cur_name + actual, expected = getattr(cur, name).data, getattr(copy, name).data + allclose = torch.allclose(actual, expected) + assert not allclose if self._resample_parameters else allclose + + pruning = ModelPruningTestCallback( + "l1_unstructured", use_lottery_ticket_hypothesis=lambda e: bool(e % 2), resample_parameters=resample_parameters + ) trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=0, @@ -197,6 +209,6 @@ def apply_lottery_ticket_hypothesis(self): max_epochs=5, callbacks=pruning, ) - assert pruning.lth_calls == 0 trainer.fit(model) + assert pruning.lth_calls == trainer.max_epochs // 2 From 5a455e6fefbf58d9ab521ab7a6d958299110cabc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 05:16:49 +0100 Subject: [PATCH 12/27] Iterative pruning --- pytorch_lightning/callbacks/pruning.py | 33 +++++++++++++++++--------- tests/callbacks/test_pruning.py | 24 ++++++++++++++++--- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 542e4d5c4973c..25cd9c12681d3 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -57,6 +57,7 @@ def __init__( parameter_names: Optional[List[str]] = None, use_global_unstructured: bool = True, amount: Union[int, float, Callable[[int], Union[int, float]]] = 0.5, + apply_pruning: Union[bool, Callable[[int], bool]] = True, make_pruning_permanent: bool = True, use_lottery_ticket_hypothesis: Union[bool, Callable[[int], bool]] = True, resample_parameters: bool = False, @@ -107,6 +108,11 @@ def __init__( - ``int``. Represents the absolute number of parameters to prune. - ``Callable``. For dynamic values. Will be called every epoch. Should return a value. + apply_pruning: Whether to apply pruning. + + - ``bool``. Always apply it or not. + - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. + make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks on fit end. 
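
To make the two callable forms documented above concrete, here are schedules a user might pass for ``apply_pruning`` and ``amount``. The helper names are made up; this is a sketch, not part of the patch:

.. code-block:: python

    def apply_every_other_epoch(epoch: int) -> bool:
        # prune on even epochs only
        return epoch % 2 == 0

    def ramped_amount(epoch: int) -> float:
        # prune 30% per pruning step early on, then taper to 5%
        return 0.3 if epoch < 10 else 0.05

    # e.g. ModelPruning("l1_unstructured", apply_pruning=apply_every_other_epoch, amount=ramped_amount)
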
use_lottery_ticket_hypothesis: See "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf): @@ -177,7 +183,8 @@ def __init__( ) self.pruning_fn = pruning_fn - self.make_pruning_permanent = make_pruning_permanent + self._apply_pruning = apply_pruning + self._make_pruning_permanent = make_pruning_permanent if not isinstance(amount, (int, float, Callable)): raise MisconfigurationException( @@ -210,7 +217,7 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor def _wrap_pruning_fn(pruning_fn, **kwargs): return partial(pruning_fn, **kwargs) - def _make_pruning_permanent(self): + def make_pruning_permanent(self): for module, param_name in self._parameters_to_prune: pytorch_prune.remove(module, param_name) @@ -268,20 +275,13 @@ def _apply_global_pruning(self, amount: float): def apply_pruning(self, current_epoch: int): amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount - # the user could control the pruning frequency with amount_fn if not amount: return - if self._use_global_unstructured: self._apply_global_pruning(amount) else: self._apply_local_pruning(amount) - if self._use_lottery_ticket_hypothesis(current_epoch) if isinstance( - self._use_lottery_ticket_hypothesis, Callable - ) else self._use_lottery_ticket_hypothesis: - self.apply_lottery_ticket_hypothesis() - def on_before_accelerator_backend_setup(self, trainer, pl_module): parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names @@ -299,10 +299,21 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): self._original_layers[id_]["names"].append((i, name)) def on_train_epoch_end(self, trainer, pl_module, *args): + current_epoch = trainer.current_epoch + if not ( + self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning + ): + return self.apply_pruning(trainer.current_epoch) - if self.make_pruning_permanent: - self._make_pruning_permanent() + if ( + self._use_lottery_ticket_hypothesis(current_epoch) + if isinstance(self._use_lottery_ticket_hypothesis, Callable) else self._use_lottery_ticket_hypothesis + ): + self.apply_lottery_ticket_hypothesis() + + if self._make_pruning_permanent: + self.make_pruning_permanent() @staticmethod def sanitize_parameters_to_prune( diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 7a42e5865508a..b4e4d29dac179 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -189,9 +189,9 @@ def apply_lottery_ticket_hypothesis(self): for d in self._original_layers.values(): copy, names = d["data"], d["names"] for i, name in names: - cur, cur_name = self._parameters_to_prune[i] - assert name == cur_name - actual, expected = getattr(cur, name).data, getattr(copy, name).data + curr, curr_name = self._parameters_to_prune[i] + assert name == curr_name + actual, expected = getattr(curr, name).data, getattr(copy, name).data allclose = torch.allclose(actual, expected) assert not allclose if self._resample_parameters else allclose @@ -212,3 +212,21 @@ def apply_lottery_ticket_hypothesis(self): trainer.fit(model) assert pruning.lth_calls == trainer.max_epochs // 2 + + +def test_multiple_pruning_callbacks(tmpdir): + model = TestModel() + p1 = ModelPruning("l1_unstructured", apply_pruning=lambda e: bool(e % 2)) + p2 = ModelPruning("random_unstructured", apply_pruning=lambda e: not bool(e % 2)) + trainer = Trainer( + default_root_dir=tmpdir, + 
progress_bar_refresh_rate=0, + weights_summary=None, + checkpoint_callback=False, + logger=False, + limit_train_batches=10, + limit_val_batches=2, + max_epochs=5, + callbacks=[p1, p2], + ) + trainer.fit(model) From 789470f0e259f608c2da2084f0bcdab23de49012 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 06:16:17 +0100 Subject: [PATCH 13/27] Log pruning percentage --- pytorch_lightning/callbacks/pruning.py | 35 ++++++++++++++++++++------ tests/callbacks/test_pruning.py | 32 +++++++++++++++++++---- 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 25cd9c12681d3..b758edf542a3b 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -27,6 +27,7 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException _PYTORCH_PRUNING_FUNCTIONS = { @@ -63,6 +64,7 @@ def __init__( resample_parameters: bool = False, pruning_dim: Optional[int] = None, pruning_norm: Optional[int] = None, + verbose: bool = False, ) -> None: """ Model pruning Callback, using PyTorch's prune utilities. @@ -127,6 +129,8 @@ def __init__( pruning_norm: If you are using ``ln_structured`` you need to specify the norm. + verbose: Whether to log pruning percentage changes. + """ self._use_global_unstructured = use_global_unstructured @@ -192,6 +196,7 @@ def __init__( ) self.amount = amount + self._verbose = verbose def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: """ @@ -273,15 +278,29 @@ def _apply_global_pruning(self, amount: float): self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) ) - def apply_pruning(self, current_epoch: int): - amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount - if not amount: - return + def _get_pruned_percentage(self, module: nn.Module, name: str) -> float: + attr = f"{name}_mask" + if not hasattr(module, attr): + return 0.0 + mask = getattr(module, attr) + return (torch.sum(mask == 0) / mask.numel()).item() + + def apply_pruning(self, amount: Union[int, float]): + if self._verbose: + stats = [self._get_pruned_percentage(m, n) for m, n in self._parameters_to_prune] + if self._use_global_unstructured: self._apply_global_pruning(amount) else: self._apply_local_pruning(amount) + if self._verbose: + for i, (module, name) in enumerate(self._parameters_to_prune): + rank_zero_info( + f"Applied `{self.pruning_fn.__name__}` to `{module!r}.{name}` with amount={amount}." 
+ f" Pruned {stats[i]:.2%} -> {self._get_pruned_percentage(module, name):.2%}" + ) + def on_before_accelerator_backend_setup(self, trainer, pl_module): parameters_to_prune = self.sanitize_parameters_to_prune( pl_module, self._parameters_to_prune, parameter_names=self._parameter_names @@ -300,11 +319,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module, *args): current_epoch = trainer.current_epoch - if not ( - self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning - ): + prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning + amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount + if not prune or not amount: return - self.apply_pruning(trainer.current_epoch) + self.apply_pruning(amount) if ( self._use_lottery_ticket_hypothesis(current_epoch) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index b4e4d29dac179..6b273dc3b6192 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -14,6 +14,8 @@ import os import platform from collections import OrderedDict +from logging import INFO +from unittest import mock import pytest import torch @@ -21,7 +23,7 @@ from torch import nn from torch.nn import Sequential -from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -214,10 +216,17 @@ def apply_lottery_ticket_hypothesis(self): assert pruning.lth_calls == trainer.max_epochs // 2 -def test_multiple_pruning_callbacks(tmpdir): +@mock.patch.dict(os.environ, {}, clear=True) +def test_multiple_pruning_callbacks(tmpdir, caplog): + seed_everything(0) model = TestModel() - p1 = ModelPruning("l1_unstructured", apply_pruning=lambda e: bool(e % 2)) - p2 = ModelPruning("random_unstructured", apply_pruning=lambda e: not bool(e % 2)) + pruning_kwargs = { + 'parameters_to_prune': [(model.layer.mlp_1, "weight")], + 'make_pruning_permanent': False, + 'verbose': True + } + p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs) + p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **pruning_kwargs) trainer = Trainer( default_root_dir=tmpdir, progress_bar_refresh_rate=0, @@ -229,4 +238,17 @@ def test_multiple_pruning_callbacks(tmpdir): max_epochs=5, callbacks=[p1, p2], ) - trainer.fit(model) + + with caplog.at_level(INFO): + trainer.fit(model) + + actual = [m.strip() for m in caplog.messages[-5:]] + layer = "Linear(in_features=32, out_features=32, bias=True)" + expected = [ + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned 0.00% -> 50.00%", + f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned 50.00% -> 62.50%", + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned 62.50% -> 81.25%", + f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned 81.25% -> 85.94%", + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. 
Pruned 85.94% -> 92.97%" + ] + assert actual == expected From 87c720f5422b0737208f221d26b2221d682a75ec Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 07:08:25 +0100 Subject: [PATCH 14/27] Properly make pruning permanent --- pytorch_lightning/callbacks/pruning.py | 29 +++++++++++++++++------- tests/callbacks/test_pruning.py | 31 +++++++++++++++++--------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index b758edf542a3b..7339e55f0c18e 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -115,7 +115,8 @@ def __init__( - ``bool``. Always apply it or not. - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. - make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks on fit end. + make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks + when training ends or the model is saved. use_lottery_ticket_hypothesis: See "The lottery ticket hypothesis" (https://arxiv.org/pdf/1803.03635.pdf): @@ -224,7 +225,11 @@ def _wrap_pruning_fn(pruning_fn, **kwargs): def make_pruning_permanent(self): for module, param_name in self._parameters_to_prune: - pytorch_prune.remove(module, param_name) + try: + pytorch_prune.remove(module, param_name) + except ValueError: + # pruning already made permanent + pass def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str): trained = getattr(module, tensor_name) @@ -278,16 +283,16 @@ def _apply_global_pruning(self, amount: float): self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) ) - def _get_pruned_percentage(self, module: nn.Module, name: str) -> float: + def _get_pruned_stats(self, module: nn.Module, name: str) -> Tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): - return 0.0 + return 0, 1 mask = getattr(module, attr) - return (torch.sum(mask == 0) / mask.numel()).item() + return (mask == 0).sum().item(), mask.numel() def apply_pruning(self, amount: Union[int, float]): if self._verbose: - stats = [self._get_pruned_percentage(m, n) for m, n in self._parameters_to_prune] + stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] if self._use_global_unstructured: self._apply_global_pruning(amount) @@ -296,9 +301,12 @@ def apply_pruning(self, amount: Union[int, float]): if self._verbose: for i, (module, name) in enumerate(self._parameters_to_prune): + prev_mask_zeroes, prev_mask_size = stats[i] + curr_mask_zeroes, curr_mask_size = self._get_pruned_stats(module, name) rank_zero_info( - f"Applied `{self.pruning_fn.__name__}` to `{module!r}.{name}` with amount={amount}." - f" Pruned {stats[i]:.2%} -> {self._get_pruned_percentage(module, name):.2%}" + f"Applied `{self.pruning_fn.__name__}` to `{module!r}.{name}` with amount={amount}. 
Pruned:" + f" {prev_mask_zeroes} ({prev_mask_zeroes / prev_mask_size:.2%}) ->" + f" {curr_mask_zeroes} ({curr_mask_zeroes / curr_mask_size:.2%})" ) def on_before_accelerator_backend_setup(self, trainer, pl_module): @@ -331,6 +339,11 @@ def on_train_epoch_end(self, trainer, pl_module, *args): ): self.apply_lottery_ticket_hypothesis() + def on_train_end(self, *args): + if self._make_pruning_permanent: + self.make_pruning_permanent() + + def on_save_checkpoint(self, *args): if self._make_pruning_permanent: self.make_pruning_permanent() diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index 6b273dc3b6192..d0479e95ac2fd 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -24,13 +24,12 @@ from torch.nn import Sequential from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import ModelPruning +from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel class TestModel(BoringModel): - validation_step = None test_step = None def __init__(self): @@ -216,14 +215,15 @@ def apply_lottery_ticket_hypothesis(self): assert pruning.lth_calls == trainer.max_epochs // 2 +@pytest.mark.parametrize("make_pruning_permanent", (False, True)) @mock.patch.dict(os.environ, {}, clear=True) -def test_multiple_pruning_callbacks(tmpdir, caplog): +def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent): seed_everything(0) model = TestModel() pruning_kwargs = { 'parameters_to_prune': [(model.layer.mlp_1, "weight")], - 'make_pruning_permanent': False, - 'verbose': True + 'verbose': True, + "make_pruning_permanent": make_pruning_permanent } p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs) p2 = ModelPruning("random_unstructured", amount=0.25, apply_pruning=lambda e: e % 2, **pruning_kwargs) @@ -238,17 +238,26 @@ def test_multiple_pruning_callbacks(tmpdir, caplog): max_epochs=5, callbacks=[p1, p2], ) - with caplog.at_level(INFO): trainer.fit(model) actual = [m.strip() for m in caplog.messages[-5:]] layer = "Linear(in_features=32, out_features=32, bias=True)" expected = [ - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned 0.00% -> 50.00%", - f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned 50.00% -> 62.50%", - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned 62.50% -> 81.25%", - f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned 81.25% -> 85.94%", - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned 85.94% -> 92.97%" + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned: 0 (0.00%) -> 512 (50.00%)", + f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned: 512 (50.00%) -> 640 (62.50%)", + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned: 640 (62.50%) -> 832 (81.25%)", + f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned: 832 (81.25%) -> 880 (85.94%)", + f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. 
Pruned: 880 (85.94%) -> 952 (92.97%)" ] assert actual == expected + + filepath = str(tmpdir / "foo.ckpt") + trainer.save_checkpoint(filepath) + + if not make_pruning_permanent: + # can't reload checkpoints where pruning wasn't made permanent + with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict'): + model.load_from_checkpoint(filepath) + return + assert not hasattr(model.layer.mlp_1, "weight_orig") From 24924002a9cc50b5652e51f95e43093733fdfab0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 07:13:11 +0100 Subject: [PATCH 15/27] Fix docstring --- pytorch_lightning/callbacks/pruning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 7339e55f0c18e..21a4088a092ef 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -354,8 +354,8 @@ def sanitize_parameters_to_prune( parameter_names: Optional[List[str]] = None, ) -> _PARAM_LIST: """ - This function is responsible to check provided ``parameters_to_prune` and `parameters`. - If parameters_to_prune is None, parameters_to_prune will be generated from all parameters of the model. + This function is responsible of sanitizing `parameters_to_prune` and `parameter_names`. + If `parameters_to_prune=None`, it will be generated with all parameters of the model. """ parameters = parameter_names or ModelPruning.PARAMETER_NAMES From 607d6827982615d7c0507d2eb859a26ec579f072 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 07:53:00 +0100 Subject: [PATCH 16/27] Minor changes --- pytorch_lightning/callbacks/pruning.py | 17 ++++++++++------- tests/callbacks/test_pruning.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 21a4088a092ef..01035b71a6e4b 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -70,7 +70,7 @@ def __init__( Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning networks parameters during training. - Find here the PyTorch (Pruning Tutorial)[https://pytorch.org/tutorials/intermediate/pruning_tutorial.html] + Learn more with the PyTorch pruning tutorial (https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) .. code-block:: python @@ -169,7 +169,7 @@ def __init__( pruning_kwargs["n"] = pruning_norm pruning_kwargs["dim"] = pruning_dim pruning_fn = self._create_pruning_fn(pruning_fn, **pruning_kwargs) - elif self.is_pruning_method(pruning_fn): + elif self._is_pruning_method(pruning_fn): if not use_global_unstructured: raise MisconfigurationException( "PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`." @@ -224,6 +224,7 @@ def _wrap_pruning_fn(pruning_fn, **kwargs): return partial(pruning_fn, **kwargs) def make_pruning_permanent(self): + """ Makes ``parameters_to_prune`` current pruning permanent. """ for module, param_name in self._parameters_to_prune: try: pytorch_prune.remove(module, param_name) @@ -249,7 +250,7 @@ def apply_lottery_ticket_hypothesis(self): This function implements the step 4. 
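
For context on ``make_pruning_permanent`` just above: in plain PyTorch, permanence is ``prune.remove``, which folds the mask into the tensor and drops the reparametrization pre-hook. A small illustrative sketch; it also shows why the earlier patch guards the call with ``try/except ValueError``, since a second ``remove`` on the same parameter raises:

.. code-block:: python

    import torch.nn.utils.prune as prune
    from torch import nn

    m = nn.Linear(8, 8)
    prune.random_unstructured(m, "weight", amount=0.5)
    assert hasattr(m, "weight_orig") and hasattr(m, "weight_mask")

    prune.remove(m, "weight")  # folds the mask into the tensor and removes the pre-hook
    assert not hasattr(m, "weight_orig")

    try:
        prune.remove(m, "weight")  # a second call raises: nothing left to remove
    except ValueError:
        pass
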
- The ``resample_parameters`` argument can be used to reset the parameters with a new θ ∼ D_θ + The ``resample_parameters`` argument can be used to reset the parameters with a new θ_j ∼ D_θ """ def copy_param(new, old, name: str) -> None: @@ -283,7 +284,8 @@ def _apply_global_pruning(self, amount: float): self._parameters_to_prune, pruning_method=self.pruning_fn, **self._resolve_global_kwargs(amount) ) - def _get_pruned_stats(self, module: nn.Module, name: str) -> Tuple[int, int]: + @staticmethod + def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: attr = f"{name}_mask" if not hasattr(module, attr): return 0, 1 @@ -291,6 +293,7 @@ def _get_pruned_stats(self, module: nn.Module, name: str) -> Tuple[int, int]: return (mask == 0).sum().item(), mask.numel() def apply_pruning(self, amount: Union[int, float]): + """ Applies pruning to ``parameters_to_prune``. """ if self._verbose: stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] @@ -354,8 +357,8 @@ def sanitize_parameters_to_prune( parameter_names: Optional[List[str]] = None, ) -> _PARAM_LIST: """ - This function is responsible of sanitizing `parameters_to_prune` and `parameter_names`. - If `parameters_to_prune=None`, it will be generated with all parameters of the model. + This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. + If ``parameters_to_prune == None``, it will be generated with all parameters of the model. """ parameters = parameter_names or ModelPruning.PARAMETER_NAMES @@ -390,7 +393,7 @@ def sanitize_parameters_to_prune( return parameters_to_prune @staticmethod - def is_pruning_method(method: Any) -> bool: + def _is_pruning_method(method: Any) -> bool: if not inspect.isclass(method): return False return issubclass(method, pytorch_prune.BasePruningMethod) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index d0479e95ac2fd..912e220df954d 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -24,7 +24,7 @@ from torch.nn import Sequential from pytorch_lightning import seed_everything, Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning +from pytorch_lightning.callbacks import ModelPruning from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel @@ -92,7 +92,7 @@ def train_with_pruning_callback( with pytest.raises(MisconfigurationException, match="is supported with `use_global_unstructured=True`"): ModelPruning(**pruning_kwargs) return - if ModelPruning.is_pruning_method(pruning_fn) and not use_global_unstructured: + if ModelPruning._is_pruning_method(pruning_fn) and not use_global_unstructured: with pytest.raises(MisconfigurationException, match="currently only supported with"): ModelPruning(**pruning_kwargs) return From fb416cd62a3e00a6ab587dc8cfae9dd664b5ab44 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 9 Feb 2021 16:14:31 +0100 Subject: [PATCH 17/27] Test loading non-permanent model --- pytorch_lightning/callbacks/pruning.py | 3 ++- tests/callbacks/test_pruning.py | 9 +++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 01035b71a6e4b..c349ce374f4a3 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -44,7 +44,8 @@ "random_unstructured": pytorch_prune.RandomUnstructured, } -_PARAM_LIST = List[Tuple[nn.Module, str]] +_PARAM_TUPLE = 
Tuple[nn.Module, str]
+_PARAM_LIST = Union[List[_PARAM_TUPLE], Tuple[_PARAM_TUPLE]]

 _MODULE_CONTAINERS = (LightningModule, nn.Sequential, nn.ModuleList, nn.ModuleDict)

diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 912e220df954d..acf91757ca7bf 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -255,9 +255,6 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)

-    if not make_pruning_permanent:
-        # can't reload checkpoints where pruning wasn't made permanent
-        with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict'):
-            model.load_from_checkpoint(filepath)
-        return
-    assert not hasattr(model.layer.mlp_1, "weight_orig")
+    model.load_from_checkpoint(filepath, strict=False)
+    has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
+    assert not has_pruning if make_pruning_permanent else has_pruning

From ffb8d4754710ce7473f7e14542af8f1594962b5d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 9 Feb 2021 15:31:58 +0000
Subject: [PATCH 18/27] correct bugs

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 ++++-
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 +
 pytorch_lightning/trainer/training_loop.py | 2 ++
 tests/models/test_tpu.py | 4 ++--
 4 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 6f251eb36985a..943f2fc86c4e8 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -171,6 +171,9 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]

+    def on_save(self, checkpoint: dict) -> dict:
+        return checkpoint
+
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
@@ -183,7 +186,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):

         # TODO: is there a better way than accessing trainer through model -> trainer?
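
For background on the ``on_save`` indirection added above (reverted again in the next patch): TPU checkpoints need their XLA tensors moved to CPU before saving, as the XLA guide referenced in ``tpu_spawn.py`` recommends. A rough, self-contained sketch of that idea — the helper name is an assumption, not Lightning API:

.. code-block:: python

    import torch
    from torch import nn

    def checkpoint_to_cpu(checkpoint: dict) -> dict:
        # move XLA (or GPU) tensors to CPU so the checkpoint can be serialized portably
        return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in checkpoint.items()}

    state = nn.Linear(4, 4).state_dict()
    torch.save(checkpoint_to_cpu(state), "last.tmp_end.ckpt")
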
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            atomic_save(self.lightning_module.state_dict(), last_path)
+            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

         # todo, pass complete checkpoint as state dictionary
         self.mp_queue.put(best_model_path)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 0f516e2b0b046..1f1030d75af29 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -95,6 +95,7 @@ def on_save(self, checkpoint: dict) -> dict:
         Recommended on XLA Guide:
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
+        print("on_save")
         return move_data_to_device(checkpoint, torch.device("cpu"))

     def broadcast(self, obj: object, src: int = 0) -> object:
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index c5f9f56a0099a..6fe57348b403a 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -544,6 +544,8 @@ def run_training_epoch(self):

             self.trainer.batch_idx = batch_idx

+            print("batch_idx")
+
             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
             # ------------------------------------
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 98a02d730ec9e..65fd88f45a57a 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=1,
         tpu_cores=8,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=4,
+        limit_val_batches=4,
     )

     model = EvalModelTemplate()

From 9eeade60e7b690250a30e2f3914d559587d9ad0b Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 9 Feb 2021 16:49:52 +0100
Subject: [PATCH 19/27] Revert "correct bugs"

This reverts commit ffb8d4754710ce7473f7e14542af8f1594962b5d.

---
 pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 +----
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 -
 pytorch_lightning/trainer/training_loop.py | 2 --
 tests/models/test_tpu.py | 4 ++--
 4 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index 943f2fc86c4e8..6f251eb36985a 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -171,9 +171,6 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]

-    def on_save(self, checkpoint: dict) -> dict:
-        return checkpoint
-
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
@@ -186,7 +183,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):

         # TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - atomic_save(self.on_save(self.lightning_module.state_dict()), last_path) + atomic_save(self.lightning_module.state_dict(), last_path) # todo, pass complete checkpoint as state dictionary self.mp_queue.put(best_model_path) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 1f1030d75af29..0f516e2b0b046 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -95,7 +95,6 @@ def on_save(self, checkpoint: dict) -> dict: Recommended on XLA Guide: https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors """ - print("on_save") return move_data_to_device(checkpoint, torch.device("cpu")) def broadcast(self, obj: object, src: int = 0) -> object: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 6fe57348b403a..c5f9f56a0099a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -544,8 +544,6 @@ def run_training_epoch(self): self.trainer.batch_idx = batch_idx - print("batch_idx") - # ------------------------------------ # TRAINING_STEP + TRAINING_STEP_END # ------------------------------------ diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 65fd88f45a57a..98a02d730ec9e 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir): progress_bar_refresh_rate=0, max_epochs=1, tpu_cores=8, - limit_train_batches=4, - limit_val_batches=4, + limit_train_batches=0.4, + limit_val_batches=0.4, ) model = EvalModelTemplate() From 0af4469a8737da8b0f26e42a5eeaecc20aeef4f0 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Feb 2021 01:44:15 +0100 Subject: [PATCH 20/27] Add beta warning --- pytorch_lightning/callbacks/pruning.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index c349ce374f4a3..7440011b008cb 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -15,6 +15,9 @@ ModelPruning ^^^^^^^^^^^^ +.. warning:: + + ModelPruning is in beta and subject to change. """ import inspect from copy import deepcopy From ce5c331fefc0ff6cc81a1ee4e5f13df3d47aea35 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Feb 2021 02:39:43 +0100 Subject: [PATCH 21/27] Fix docs --- pytorch_lightning/callbacks/pruning.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 7440011b008cb..e3b4477ffcf49 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -14,10 +14,6 @@ r""" ModelPruning ^^^^^^^^^^^^ - -.. warning:: - - ModelPruning is in beta and subject to change. """ import inspect from copy import deepcopy @@ -74,7 +70,10 @@ def __init__( Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning networks parameters during training. - Learn more with the PyTorch pruning tutorial (https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) + To learn more about pruning with PyTorch, please take a look at + `this tutorial `_. + + .. 
warning:: ``ModelPruning`` is in beta and subject to change.

         .. code-block:: python

@@ -99,7 +99,7 @@ def __init__(
             pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod``
                 subclass. Can also be a string, e.g. `"l1_unstructured"`. See the PyTorch docs for more details.

-            parameters_to_prune: List of strings or list of tuples ``(nn.Module, "parameter_name_string")``.
+            parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``.
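
To make the expected shape of that list concrete, here is one way it might be assembled — an illustrative sketch with a made-up model:

.. code-block:: python

    from torch import nn

    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))
    # one (module, parameter_name) tuple per prunable tensor
    parameters_to_prune = [(m, "weight") for m in model.modules() if isinstance(m, nn.Linear)]
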
parameter_names: List of parameter names to be pruned from the nn.Module. Can either be ``"weight"`` or ``"bias"``. @@ -133,7 +133,7 @@ def __init__( pruning_norm: If you are using ``ln_structured`` you need to specify the norm. - verbose: Whether to log pruning percentage changes. + verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity """ @@ -196,10 +196,14 @@ def __init__( if not isinstance(amount, (int, float, Callable)): raise MisconfigurationException( - "amount should be provided and be either an int, a float or Callable function." + "`amount` should be provided and be either an int, a float or Callable function." ) self.amount = amount + + if verbose not in (0, 1, 2): + raise MisconfigurationException("`verbose` must be any of (0, 1, 2)") + self._verbose = verbose def filter_parameters_to_prune(self, parameters_to_prune: Optional[_PARAM_LIST] = None) -> Optional[_PARAM_LIST]: @@ -298,7 +302,7 @@ def _get_pruned_stats(module: nn.Module, name: str) -> Tuple[int, int]: def apply_pruning(self, amount: Union[int, float]): """ Applies pruning to ``parameters_to_prune``. """ if self._verbose: - stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + prev_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] if self._use_global_unstructured: self._apply_global_pruning(amount) @@ -306,11 +310,27 @@ def apply_pruning(self, amount: Union[int, float]): self._apply_local_pruning(amount) if self._verbose: + curr_stats = [self._get_pruned_stats(m, n) for m, n in self._parameters_to_prune] + self._log_sparsity_stats(prev_stats, curr_stats, amount=amount) + + def _log_sparsity_stats( + self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 + ): + total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) + total_prev_pruned = sum(zeroes for zeroes, _ in prev) + total_curr_pruned = sum(zeroes for zeroes, _ in curr) + pruning_fn_name = self.pruning_fn.__name__ + rank_zero_info( + f"Applied `{pruning_fn_name}`. Pruned:" + f" {total_prev_pruned}/{total_params} ({total_prev_pruned / total_params:.2%}) ->" + f" {total_curr_pruned}/{total_params} ({total_curr_pruned / total_params:.2%})" + ) + if self._verbose == 2: for i, (module, name) in enumerate(self._parameters_to_prune): - prev_mask_zeroes, prev_mask_size = stats[i] - curr_mask_zeroes, curr_mask_size = self._get_pruned_stats(module, name) + prev_mask_zeroes, prev_mask_size = prev[i] + curr_mask_zeroes, curr_mask_size = curr[i] rank_zero_info( - f"Applied `{self.pruning_fn.__name__}` to `{module!r}.{name}` with amount={amount}. Pruned:" + f"Applied `{pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. 
Pruned:" f" {prev_mask_zeroes} ({prev_mask_zeroes / prev_mask_size:.2%}) ->" f" {curr_mask_zeroes} ({curr_mask_zeroes / curr_mask_size:.2%})" ) diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py index acf91757ca7bf..2ce555cecc3c3 100644 --- a/tests/callbacks/test_pruning.py +++ b/tests/callbacks/test_pruning.py @@ -128,6 +128,8 @@ def test_pruning_misconfiguration(): ModelPruning(pruning_fn={}) # noqa with pytest.raises(MisconfigurationException, match="should be provided"): ModelPruning(pruning_fn="random_structured") + with pytest.raises(MisconfigurationException, match=r"must be any of \(0, 1, 2\)"): + ModelPruning(pruning_fn="l1_unstructured", verbose=3) with pytest.raises(MisconfigurationException, match="requesting `ln_structured` pruning, the `pruning_norm`"): ModelPruning(pruning_fn="ln_structured", pruning_dim=0) @@ -221,8 +223,8 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent): seed_everything(0) model = TestModel() pruning_kwargs = { - 'parameters_to_prune': [(model.layer.mlp_1, "weight")], - 'verbose': True, + 'parameters_to_prune': [(model.layer.mlp_1, "weight"), (model.layer.mlp_3, "weight")], + 'verbose': 2, "make_pruning_permanent": make_pruning_permanent } p1 = ModelPruning("l1_unstructured", amount=0.5, apply_pruning=lambda e: not e % 2, **pruning_kwargs) @@ -235,20 +237,23 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent): logger=False, limit_train_batches=10, limit_val_batches=2, - max_epochs=5, + max_epochs=3, callbacks=[p1, p2], ) with caplog.at_level(INFO): trainer.fit(model) - actual = [m.strip() for m in caplog.messages[-5:]] - layer = "Linear(in_features=32, out_features=32, bias=True)" + actual = [m.strip() for m in caplog.messages[-9:]] expected = [ - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned: 0 (0.00%) -> 512 (50.00%)", - f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned: 512 (50.00%) -> 640 (62.50%)", - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned: 640 (62.50%) -> 832 (81.25%)", - f"Applied `RandomUnstructured` to `{layer}.weight` with amount=0.25. Pruned: 832 (81.25%) -> 880 (85.94%)", - f"Applied `L1Unstructured` to `{layer}.weight` with amount=0.5. Pruned: 880 (85.94%) -> 952 (92.97%)" + "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)", + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501 + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501 + "Applied `RandomUnstructured`. Pruned: 544/1122 (48.48%) -> 680/1122 (60.61%)", + "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.25. Pruned: 506 (49.41%) -> 633 (61.82%)", # noqa: E501 + "Applied `RandomUnstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.25. Pruned: 38 (59.38%) -> 47 (73.44%)", # noqa: E501 + "Applied `L1Unstructured`. Pruned: 680/1122 (60.61%) -> 884/1122 (78.79%)", + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501 + "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. 
Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501 ] assert actual == expected From 120417caed82a810b0ea0db0561a457338fedc0a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Feb 2021 03:21:43 +0100 Subject: [PATCH 23/27] OCD --- pytorch_lightning/callbacks/pruning.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index a22e3386dc61d..1cdf5ee322e17 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -317,22 +317,22 @@ def _log_sparsity_stats( self, prev: List[Tuple[int, int]], curr: List[Tuple[int, int]], amount: Union[int, float] = 0 ): total_params = sum(p.numel() for layer, _ in self._parameters_to_prune for p in layer.parameters()) - total_prev_pruned = sum(zeroes for zeroes, _ in prev) - total_curr_pruned = sum(zeroes for zeroes, _ in curr) + prev_total_zeros = sum(zeros for zeros, _ in prev) + curr_total_zeros = sum(zeros for zeros, _ in curr) pruning_fn_name = self.pruning_fn.__name__ rank_zero_info( f"Applied `{pruning_fn_name}`. Pruned:" - f" {total_prev_pruned}/{total_params} ({total_prev_pruned / total_params:.2%}) ->" - f" {total_curr_pruned}/{total_params} ({total_curr_pruned / total_params:.2%})" + f" {prev_total_zeros}/{total_params} ({prev_total_zeros / total_params:.2%}) ->" + f" {curr_total_zeros}/{total_params} ({curr_total_zeros / total_params:.2%})" ) if self._verbose == 2: for i, (module, name) in enumerate(self._parameters_to_prune): - prev_mask_zeroes, prev_mask_size = prev[i] - curr_mask_zeroes, curr_mask_size = curr[i] + prev_mask_zeros, prev_mask_size = prev[i] + curr_mask_zeros, curr_mask_size = curr[i] rank_zero_info( f"Applied `{pruning_fn_name}` to `{module!r}.{name}` with amount={amount}. Pruned:" - f" {prev_mask_zeroes} ({prev_mask_zeroes / prev_mask_size:.2%}) ->" - f" {curr_mask_zeroes} ({curr_mask_zeroes / curr_mask_size:.2%})" + f" {prev_mask_zeros} ({prev_mask_zeros / prev_mask_size:.2%}) ->" + f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})" ) def on_before_accelerator_backend_setup(self, trainer, pl_module): From 7ccfdb949ea96821b80801867e96359749b18f3c Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 10 Feb 2021 14:06:01 +0000 Subject: [PATCH 24/27] Fix formatting --- pytorch_lightning/callbacks/pruning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 1cdf5ee322e17..2ab09ccda2488 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -253,11 +253,13 @@ def apply_lottery_ticket_hypothesis(self): 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`). 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`. 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`. - 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`. + 4. Reset the remaining parameters to their values in :math:`\theta_0`, + creating the winning ticket :math:`f(x; m \odot \theta_0)`. This function implements the step 4. 
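
As a sanity check on the counts asserted above, the mask-based sparsity arithmetic can be reproduced in isolation with plain PyTorch. This sketch uses local (single-tensor) pruning, so the numbers come out rounder than the global ones in the test:

.. code-block:: python

    import torch.nn.utils.prune as prune
    from torch import nn

    layer = nn.Linear(32, 32)
    prune.l1_unstructured(layer, "weight", amount=0.5)  # prunes round(0.5 * 1024) weights

    mask = layer.weight_mask
    zeros, total = int((mask == 0).sum()), mask.numel()
    print(f"Pruned: {zeros}/{total} ({zeros / total:.2%})")  # Pruned: 512/1024 (50.00%)
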
- The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` + The ``resample_parameters`` argument can be used to reset the parameters with a new + :math:`\theta_z \sim \mathcal{D}_\theta` """ def copy_param(new, old, name: str) -> None: From 87b6cebb5d45fdf930139be4acef277f6019d7c2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Feb 2021 15:15:04 +0100 Subject: [PATCH 25/27] Revert "Fix formatting" This reverts commit 7ccfdb949ea96821b80801867e96359749b18f3c. --- pytorch_lightning/callbacks/pruning.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 2ab09ccda2488..1cdf5ee322e17 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -253,13 +253,11 @@ def apply_lottery_ticket_hypothesis(self): 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`). 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`. 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`. - 4. Reset the remaining parameters to their values in :math:`\theta_0`, - creating the winning ticket :math:`f(x; m \odot \theta_0)`. + 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`. This function implements the step 4. - The ``resample_parameters`` argument can be used to reset the parameters with a new - :math:`\theta_z \sim \mathcal{D}_\theta` + The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` """ def copy_param(new, old, name: str) -> None: From 79400365cffd79c3d57eb330f34bf529691c174d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 10 Feb 2021 15:21:56 +0100 Subject: [PATCH 26/27] Ignore E501. Otherwise math blocks don't render properly --- pyproject.toml | 2 +- pytorch_lightning/callbacks/pruning.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 331e247839145..5caf8a48648b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ [tool.autopep8] max_line_length = 120 -ignore = ["W503", "W504", "E402", "E731", "C40", "E741", "F40", "F841"] +ignore = ["W503", "E402", "E731", "C40", "E741", "F40", "F841"] [tool.isort] known_first_party = [ diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 1cdf5ee322e17..253cd0bbc4786 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -258,7 +258,7 @@ def apply_lottery_ticket_hypothesis(self): This function implements the step 4. 
The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` - """ + """ # noqa: E501 def copy_param(new, old, name: str) -> None: dst = getattr(new, name) From c7ed73f4d71e6159872cc496effca0eddd7646b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 10 Feb 2021 15:42:38 +0100 Subject: [PATCH 27/27] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5caf8a48648b2..331e247839145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ [tool.autopep8] max_line_length = 120 -ignore = ["W503", "E402", "E731", "C40", "E741", "F40", "F841"] +ignore = ["W503", "W504", "E402", "E731", "C40", "E741", "F40", "F841"] [tool.isort] known_first_party = [