Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quant as optional step #8464

Merged
merged 22 commits into from
Jul 22, 2021
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled traditional/manual launching of DDP processes through `LOCAL_RANK` and `NODE_RANK` environment variable assignments ([#7480](https://github.com/PyTorchLightning/pytorch-lightning/pull/7480))


- Added `quantize_on_fit_end` argument to `QuantizationAwareTraining` ([#8464](https://github.com/PyTorchLightning/pytorch-lightning/pull/8464))


- Added experimental support for loop specialization ([#8226](https://github.com/PyTorchLightning/pytorch-lightning/pull/8226))


- Added support for `devices` flag to Trainer ([#8440](https://github.com/PyTorchLightning/pytorch-lightning/pull/8440))



### Changed


Expand Down
87 changes: 48 additions & 39 deletions pytorch_lightning/callbacks/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,60 +82,65 @@ def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:


class QuantizationAwareTraining(Callback):
OBSERVER_TYPES = ('histogram', 'average')
"""
Borda marked this conversation as resolved.
Show resolved Hide resolved
Quantization allows speeding up inference and decreasing memory requirements
by performing computations and storing tensors at lower bitwidths
(such as INT8 or FLOAT16) than floating point precision.
We use native PyTorch API so for more information
see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.

def __init__(
self,
qconfig: Union[str, QConfig] = 'fbgemm',
observer_type: str = "average",
collect_quantization: Optional[Union[int, Callable]] = None,
modules_to_fuse: Optional[Sequence] = None,
input_compatible: bool = True,
) -> None:
"""
Quantization allows speeding up inference and decreasing memory requirements
by performing computations and storing tensors at lower bitwidths
(such as INT8 or FLOAT16) than floating point precision.
We use native PyTorch API so for more information
see `Quantization <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_.

.. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
Args:

qconfig: quantization configuration:

Args:
- 'fbgemm' for server inference.
- 'qnnpack' for mobile inference.
- a custom `torch.quantization.QConfig
<https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.

qconfig: quantization configuration:
observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
and ``HistogramObserver`` as "histogram" which is more computationally expensive.

- 'fbgemm' for server inference.
- 'qnnpack' for mobile inference.
- a custom `torch.quantization.QConfig <https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig>`_.
collect_quantization: count or custom function to collect quantization statistics:

observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
and ``HistogramObserver`` as "histogram" which is more computationally expensive.
- ``None`` (deafult). The quantization observer is called in each module forward
(useful for collecting extended statistic when useing image/data augmentation).
- ``int``. Use to set a fixed number of calls, starting from the beginning.
- ``Callable``. Custom function with single trainer argument.
See this example to trigger only the last epoch:

collect_quantization: count or custom function to collect quantization statistics:
.. code-block:: python
Borda marked this conversation as resolved.
Show resolved Hide resolved

- ``None`` (deafult). The quantization observer is called in each module forward
(useful for collecting extended statistic when useing image/data augmentation).
- ``int``. Use to set a fixed number of calls, starting from the beginning.
- ``Callable``. Custom function with single trainer argument.
See this example to trigger only the last epoch:
def custom_trigger_last(trainer):
return trainer.current_epoch == (trainer.max_epochs - 1)

.. code-block:: python
QuantizationAwareTraining(collect_quantization=custom_trigger_last)

def custom_trigger_last(trainer):
return trainer.current_epoch == (trainer.max_epochs - 1)
modules_to_fuse: allows you fuse a few layers together as shown in
`diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

QuantizationAwareTraining(collect_quantization=custom_trigger_last)
input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
but break compatibility to torchscript and export with ``torch.save``.

modules_to_fuse: allows you fuse a few layers together as shown in
`diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
quantize_on_fit_end: perform the quantization in `on_fit_end`.
Note that once converted, the model cannot be put in training mode again.

input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
but break compatibility to torchscript.
"""
OBSERVER_TYPES = ('histogram', 'average')

""" # noqa: E501
def __init__(
self,
qconfig: Union[str, QConfig] = 'fbgemm',
observer_type: str = "average",
collect_quantization: Optional[Union[int, Callable]] = None,
modules_to_fuse: Optional[Sequence] = None,
input_compatible: bool = True,
quantize_on_fit_end: bool = True,
) -> None:
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
raise MisconfigurationException(
Expand All @@ -157,6 +162,7 @@ def custom_trigger_last(trainer):

self.modules_to_fuse = modules_to_fuse
self._input_compatible = input_compatible
self._convert_on_fit_end = quantize_on_fit_end
self._forward_calls = 0

def _check_feasible_fuse(self, model):
Expand Down Expand Up @@ -199,6 +205,9 @@ def on_fit_start(self, trainer, pl_module):
torch.quantization.prepare_qat(pl_module, inplace=True)

def on_fit_end(self, trainer, pl_module):
if not self._convert_on_fit_end:
pl_module.forward = self.__module_forward
return
pl_module.eval()
# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
Expand Down
6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,8 +1227,10 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
return output

def _parse_devices(
self, gpus: Optional[Union[List[int], str, int]], auto_select_gpus: bool, tpu_cores: Optional[Union[List[int],
str, int]]
self,
gpus: Optional[Union[List[int], str, int]],
auto_select_gpus: bool,
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
if auto_select_gpus and isinstance(gpus, int):
gpus = pick_multiple_gpus(gpus)
Expand Down
28 changes: 21 additions & 7 deletions tests/callbacks/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@

@pytest.mark.parametrize("observe", ['average', 'histogram'])
@pytest.mark.parametrize("fuse", [True, False])
@pytest.mark.parametrize("convert", [True, False])
@RunIf(quantization=True)
def test_quantization(tmpdir, observe: str, fuse: bool):
def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
"""Parity test for quant model"""
seed_everything(42)
dm = RegressDataModule()
trainer_args = dict(
default_root_dir=tmpdir,
max_epochs=10,
gpus=1 if torch.cuda.is_available() else None,
max_epochs=7,
gpus=int(torch.cuda.is_available()),
)
model = RegressionModel()
qmodel = copy.deepcopy(model)
Expand All @@ -47,20 +48,33 @@ def test_quantization(tmpdir, observe: str, fuse: bool):
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))

fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers)
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
trainer = Trainer(callbacks=[qcb], **trainer_args)
trainer.fit(qmodel, datamodule=dm)

quant_calls = qcb._forward_calls
assert quant_calls == qcb._forward_calls
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
# test that the test score is almost the same as with pure training
assert torch.allclose(org_score, quant_score, atol=0.45)
model_path = trainer.checkpoint_callback.best_model_path

trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
if not convert:
trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
trainer.fit(qmodel, datamodule=dm)
qmodel.eval()
torch.quantization.convert(qmodel, inplace=True)

quant_size = qmodel.model_size
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
# test that the trained model is smaller then initial
size_ratio = quant_size / org_size
assert size_ratio < 0.65
# test that the test score is almost the same as with pure training
assert torch.allclose(org_score, quant_score, atol=0.45)

# todo: make it work also with strict loading
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
assert torch.allclose(org_score, quant2_score, atol=0.45)


@RunIf(quantization=True)
Expand Down