MultiGPU training + changes to Checkpointing logic #218

Merged: 68 commits, May 25, 2023
Commits
All 68 commits are by prabhuteja12:

* 6627ae6 strategy maker (Apr 24, 2023)
* 67ac54c strategy maker with license (Apr 24, 2023)
* 3bf62aa Merge branch 'main' into enable_multigpu (Apr 24, 2023)
* e965bb2 tweaks to strategy maker (Apr 24, 2023)
* 45bdd22 interfaces changes to allow precision, strategy (Apr 24, 2023)
* 375944a adding defaults (Apr 24, 2023)
* 1404843 argparse additions (Apr 24, 2023)
* 557a8e5 nlp example to local (Apr 24, 2023)
* f51aae2 precision, utils, rank_zero prints (May 1, 2023)
* 5eecf14 simplified delete folder on rank zero (May 3, 2023)
* 89766fa moving loss to learner (May 3, 2023)
* 671cada Merge remote-tracking branch 'origin/dev' into checkpointing (May 4, 2023)
* cb4b10d distributed strategy flags (May 4, 2023)
* d83374f utils cleanup (May 4, 2023)
* 5f20691 adding strategy and precision to exps (May 4, 2023)
* 3a60ca8 loss fn in experiment config (May 4, 2023)
* fd61b3f converting lambda to function for mp spawn (May 4, 2023)
* 0c9bebc changing state dicts to checkpoint management (May 4, 2023)
* eda2891 replace manual saving by ModelCheckpoint (May 4, 2023)
* 9d22177 eval with only one device (May 6, 2023)
* 0ac6fad adding device, strategy (May 6, 2023)
* 2103b73 merging dev (May 6, 2023)
* d97685f hyperparameter exceptions to checkpointing (May 9, 2023)
* 14fa03c hyperparameter exceptions to checkpointing (May 9, 2023)
* 29b3f98 pickling extra args for deepspeed compat (May 10, 2023)
* 7fda217 loss_fn tests and misc (May 10, 2023)
* b17ac3d avalanche changes to enable device, loss, prec (May 16, 2023)
* c0d9e59 changes to tests (May 16, 2023)
* e8ba6ee checkpoint callback, cleanup (May 16, 2023)
* 48a2030 deepspeed utlitiy to unshard (May 16, 2023)
* baeb4b4 conftest updates (May 16, 2023)
* b2f78b2 more cleanup (May 16, 2023)
* 5e18c22 Merge branch 'dev' of https://github.com/awslabs/Renate into checkpoi… (May 16, 2023)
* a921cb7 deepspeed to requirements (May 17, 2023)
* d231edd doc building fixes (May 17, 2023)
* 3177905 linting changes (May 17, 2023)
* 43b6db8 single device integration test (May 17, 2023)
* 9b3bc1f temporary fix as in #236 (May 21, 2023)
* 1013b50 lint changes (May 21, 2023)
* df40dde loss in config_scenario (May 21, 2023)
* b769aed single device test (May 21, 2023)
* 5576a87 linting changes (May 21, 2023)
* fc02ce8 flake8 changes (May 22, 2023)
* 6f8158a remove commented code (May 22, 2023)
* 67c33e3 increasing maxtime (May 22, 2023)
* 5755331 increasing maxtime (May 22, 2023)
* 06b2735 max time in integration tests (May 22, 2023)
* adb34a5 checkpoint loading fix (May 22, 2023)
* d8abdf1 undoing changes in nlp example (May 22, 2023)
* 5e4b16d documentation line number changes (May 22, 2023)
* 99f0174 addressing misc comments (May 22, 2023)
* 359a7c7 reducing max_time (May 22, 2023)
* f9f1ecb fixing linting erros (May 22, 2023)
* de02ae2 addressing comments (May 22, 2023)
* 66e5cec doc changes for loss_fn (May 23, 2023)
* e49da0a nlp documentation (May 23, 2023)
* a5f72bb reorganizing utils funcs (May 23, 2023)
* 4547dd3 removing post init and set_ funcs, tests (May 23, 2023)
* 10f65fc flake8 (May 23, 2023)
* 6f9a515 Merge branch 'dev' into checkpointing (May 23, 2023)
* d1337a1 increasing max time (May 23, 2023)
* 98055bd simplifying deletion and rank zero (May 23, 2023)
* 02c89f6 removing mentions of loss fn in docstrings of models (May 25, 2023)
* c4051d3 removing save hyperparam call in child learners (May 25, 2023)
* 3291d7a detailed docstrings for model checkpoint callback (May 25, 2023)
* 1eea37a Merge branch 'dev' into checkpointing (May 25, 2023)
* 757bf8e sphinx build errors for teardown (May 25, 2023)
* c486c4b Merge branch 'checkpointing' of github.com:awslabs/Renate into checkp… (May 25, 2023)
35 changes: 35 additions & 0 deletions doc/examples/nlp_finetuning.rst
@@ -20,6 +20,9 @@ data module expects the name of a dataset as well as a tokenizer. Here, we load
dataset in the first training stage (:code:`chunk_id = 0`) and the :code:`"rotten_tomatoes"` dataset
for the subsequent model update (:code:`chunk_id = 1`).

The function :code:`loss_fn` defines the appropriate loss criterion. As this is a classification
problem, we use :code:`torch.nn.CrossEntropyLoss`.

The data module will return pre-tokenized data and no further transforms are needed in this case.

.. literalinclude:: ../../examples/nlp_finetuning/renate_config.py
@@ -35,4 +38,36 @@ on this see previous examples or :doc:`../getting_started/how_to_run_training`.
:lines: 3-


Support for training large models
---------------------------------

To support training larger models, we expose arguments in :code:`run_experiment_job` that enable
training on multiple GPUs. For this, we rely on the strategy functionality provided by Lightning; see
the `large model tutorial <https://lightning.ai/docs/pytorch/stable/advanced/model_parallel.html>`_
and the `documentation <https://lightning.ai/docs/pytorch/stable/extensions/strategy.html>`_.
Currently, we support the following strategies:

* `"ddp_find_unused_parameters_false"`
* `"ddp"`
* `"deepspeed"`
* `"deepspeed_stage_1"`
* `"deepspeed_stage_2"`
* `"deepspeed_stage_2_offload"`
* `"deepspeed_stage_3"`
* `"deepspeed_stage_3_offload"`
* `"deepspeed_stage_3_offload_nvme"`

These can be enabled by passing one of the above options to :code:`strategy`. The number of devices
to be used for parallel training can be specified using the :code:`devices` argument, which defaults
to `1`. We also support lower-precision training by passing the :code:`precision` argument, which
accepts the options `"16"`, `"32"`, `"64"`, and `"bf16"`. Note that it has to be a string, not the
integer `32`. `"bf16"` is restricted to newer hardware and thus needs slightly more attention before
using it.

See the last four lines of the previous code example, reproduced below; an illustrative full call
follows the excerpt.

.. literalinclude:: ../../examples/nlp_finetuning/start.py
:lines: 47-49
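
A minimal sketch of a full launcher call with these options is shown below. Only the
:code:`devices`, :code:`strategy`, and :code:`precision` values are taken from this example; the
entry point and the remaining arguments are illustrative placeholders and may differ in your setup.

.. code-block:: python

    from renate.training import run_training_job

    # Illustrative sketch only: devices/strategy/precision mirror the
    # nlp_finetuning example above; the other arguments are placeholders.
    run_training_job(
        config_file="renate_config.py",
        updater="FineTuning",
        max_epochs=5,
        chunk_id=0,
        backend="local",
        devices=2,                     # number of GPUs used for parallel training
        strategy="deepspeed_stage_2",  # one of the strategies listed above
        precision="16",                # a string, not an integer
    )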

23 changes: 20 additions & 3 deletions doc/getting_started/how_to_renate_config.rst
@@ -49,6 +49,23 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`.
return model


Loss Definition
================

This function returns a :code:`torch.nn.Module` object that computes the loss. Its signature is

.. code-block:: python

def loss_fn() -> torch.nn.Module:

An example of this for the MNIST classification task above is given below.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Loss function example
:lines: 95-96
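
Those two lines, as added to ``examples/getting_started/renate_config.py`` in this pull request,
are simply:

.. code-block:: python

    def loss_fn() -> torch.nn.Module:
        return torch.nn.CrossEntropyLoss()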


Data Preparation
================

@@ -67,7 +84,7 @@ such as data subsampling or splitting.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
:lines: 43-66
:lines: 41-68

Transforms
==========
@@ -112,7 +129,7 @@ These are optional as well but, if omitted, Renate will use :code:`train_transfo

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
:lines: 73-90
:lines: 71-78

Custom Metrics
==============
@@ -124,7 +141,7 @@ or created ad-hoc by implementing the same interface

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
:lines: 93-
:lines: 91-
Review comment (Contributor): other line numbers don't require changes? does the doc page still look ok?

Reply (prabhuteja12, Contributor Author): The main one yes.

To enable the usage of additional metrics in Renate it is sufficient to implement the
:code:`metrics_fn` function, returning a dictionary where the key is a string containing the
10 changes: 6 additions & 4 deletions examples/getting_started/renate_config.py
@@ -13,11 +13,9 @@

class MyMNISTMLP(RenateModule):
def __init__(self, num_hidden: int) -> None:
# Model hyperparameters as well as the loss function need to registered via RenateModule's
# Model hyperparameters need to be registered via RenateModule's
# constructor, see documentation. Otherwise, this is a standard torch model.
super().__init__(
constructor_arguments={"num_hidden": num_hidden}, loss_fn=torch.nn.CrossEntropyLoss()
)
super().__init__(constructor_arguments={"num_hidden": num_hidden})
self._fc1 = torch.nn.Linear(28 * 28, num_hidden)
self._fc2 = torch.nn.Linear(num_hidden, 10)

@@ -92,3 +90,7 @@ def buffer_transform() -> Callable:

def metrics_fn() -> Dict:
return {"my_accuracy": Accuracy()}


def loss_fn() -> torch.nn.Module:
return torch.nn.CrossEntropyLoss()
6 changes: 5 additions & 1 deletion examples/nlp_finetuning/renate_config.py
@@ -17,13 +17,17 @@ def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2, return_dict=False
)
model = RenateWrapper(transformer_model, loss_fn=torch.nn.CrossEntropyLoss())
model = RenateWrapper(transformer_model)
if model_state_url is not None:
state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model


def loss_fn() -> torch.nn.Module:
return torch.nn.CrossEntropyLoss()


def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule:
"""Returns one of two movie review datasets depending on `chunk_id`."""
tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
3 changes: 3 additions & 0 deletions examples/nlp_finetuning/start.py
@@ -44,4 +44,7 @@
instance_count=1,
instance_type="ml.g4dn.xlarge",
job_name="renate-training-nlp-finetuning",
devices=1,
strategy="deepspeed_stage_2",
precision="32",
)
1 change: 1 addition & 0 deletions requirements.txt
@@ -10,6 +10,7 @@ Pillow>=9.0, <9.5.1
tabulate>=0.9.0, <0.9.1
torchmetrics~=0.10.3
torchvision>=0.13.0, <0.15.2
deepspeed==0.9.1
datasets>=2.9.0, < 2.12.1
transformers>4.23.0, <4.29.3
scipy>=1.9.0, <1.10.2
8 changes: 6 additions & 2 deletions src/renate/benchmark/datasets/vision_datasets.py
@@ -172,17 +172,21 @@ def setup(self) -> None:
self._data_path,
train=True,
transform=transforms.ToTensor(),
target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
target_transform=transforms.Lambda(to_long),
)
self._train_data, self._val_data = self._split_train_val_data(train_data)
self._test_data = cls(
self._data_path,
train=False,
transform=transforms.ToTensor(),
target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)),
target_transform=transforms.Lambda(to_long),
)


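# Module-level function (rather than a lambda) so the target transform stays picklable
# when training or data loading spawns new worker processes.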
def to_long(x):
return torch.tensor(x, dtype=torch.long)


class CLEARDataModule(RenateDataModule):
"""Datamodule that process CLEAR datasets: CLEAR10 and CLEAR100.

4 changes: 4 additions & 0 deletions src/renate/benchmark/experiment_config.py
@@ -231,6 +231,10 @@ def get_scenario(
raise ValueError(f"Unknown scenario `{scenario_name}`.")


def loss_fn() -> torch.nn.Module:
return torch.nn.CrossEntropyLoss()


def data_module_fn(
data_path: str,
chunk_id: int,
19 changes: 17 additions & 2 deletions src/renate/benchmark/experimentation.py
@@ -147,6 +147,8 @@ def execute_experiment_job(
devices: int = defaults.DEVICES,
deterministic_trainer: bool = True,
job_name: str = defaults.JOB_NAME,
strategy: str = defaults.DISTRIBUTED_STRATEGY,
precision: str = defaults.PRECISION,
) -> None:
"""Executes the experiment job.

@@ -202,6 +204,8 @@
devices=devices,
deterministic_trainer=deterministic_trainer,
seed=seed,
strategy=strategy,
precision=precision,
)
_execute_experiment_job_remotely(
job_name=job_name,
@@ -226,6 +230,8 @@
instance_type=instance_type,
instance_count=instance_count,
instance_max_time=instance_max_time,
strategy=strategy,
precision=precision,
)


@@ -246,6 +252,8 @@ def _execute_experiment_job_locally(
max_num_trials_finished: int,
n_workers: int,
deterministic_trainer: bool,
strategy: str,
precision: str,
) -> None:
"""Runs an experiment, combining hyperparameter tuning and model for multiple updates.

@@ -291,6 +299,9 @@
model_url,
)

# TODO: evaluate's trainer has to use devices=1:
# See https://github.com/Lightning-AI/lightning/issues/2537
# The fix is to launch evaluation in a separate process like training.
results: Dict[str, List[List[float]]] = {}
evaluate_and_record_results(
results,
Expand All @@ -301,7 +312,9 @@ def _execute_experiment_job_locally(
logged_metrics=metrics,
metric_postfix="_init",
accelerator=accelerator,
devices=devices,
devices=1,
strategy=strategy,
precision=precision,
)

for update_id in range(num_updates):
@@ -329,6 +342,8 @@
seed=seed,
accelerator=accelerator,
devices=devices,
precision=precision,
strategy=strategy,
deterministic_trainer=deterministic_trainer,
)
move_to_uri(output_state_url, input_state_url)
@@ -347,7 +362,7 @@
target_transform=transforms.get("target_test_transform"),
logged_metrics=metrics,
accelerator=accelerator,
devices=devices,
devices=1,
)
df = individual_metrics_summary(results, update_id + 1, num_updates)
save_pandas_df_to_csv(
30 changes: 19 additions & 11 deletions src/renate/benchmark/models/base.py
@@ -9,6 +9,7 @@
from renate import defaults
from renate.models import RenateModule
from renate.models.prediction_strategies import ICaRLClassificationStrategy, PredictionStrategy
from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor


# TODO: merge unit tests for the submodules
@@ -23,7 +24,6 @@ class RenateBenchmarkingModule(RenateModule, ABC):
embedding_size: Representation size of the model after the backbone.
num_outputs: The number of outputs of the model.
constructor_arguments: Arguments needed to instantiate the model.
loss_fn: The loss function to be optimized during the training.
prediction_strategy: By default a forward pass through the model. Some ModelUpdater must
be combined with specific prediction strategies to work as intended.
add_icarl_class_means: Specific parameters for iCaRL. Can be set to ``False`` if any other
@@ -35,15 +35,13 @@
embedding_size: int,
num_outputs: int,
constructor_arguments: dict,
loss_fn: torch.nn.Module,
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
):
constructor_arguments["num_outputs"] = num_outputs
constructor_arguments["add_icarl_class_means"] = add_icarl_class_means
super().__init__(
constructor_arguments=constructor_arguments,
loss_fn=loss_fn,
)
self._embedding_size = embedding_size
self._num_outputs = num_outputs
@@ -83,13 +81,23 @@ def get_params(self, task_id: str = defaults.TASK_ID) -> List[torch.nn.Parameter
self.get_predictor(task_id=task_id).parameters()
)

def get_extra_state(self) -> Any:
"""Get the constructor_arguments, loss and task ids necessary to reconstruct the model."""
extra_state = super().get_extra_state()
def get_extra_state(self, encode=True) -> Any:
"""Get the constructor_arguments and task ids necessary to reconstruct the model.

Encode converts the state into a torch tensor so that Deepspeed serialization works.
We don't encode any of the super() calls, but encode only the final dict.
"""
extra_state = super().get_extra_state(encode=not encode)
extra_state["prediction_strategy"] = self._prediction_strategy
return extra_state
return convert_to_tensor(extra_state) if encode else extra_state

def set_extra_state(self, state: Any):
"""Extract the content of the ``_extra_state`` and set the related values in the module."""
super().set_extra_state(state)
self._prediction_strategy = state["prediction_strategy"]
def set_extra_state(self, state: Any, decode=True):
"""Extract the content of the ``_extra_state`` and set the related values in the module.

decode flag is to decode the tensor of pkl bytes."""
super().set_extra_state(state, decode=decode)
self._prediction_strategy = (
recover_object_from_tensor(state)["prediction_strategy"]
if decode
else state["prediction_strategy"]
)
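
Note on the helpers imported at the top of this file: convert_to_tensor and
recover_object_from_tensor (from renate.utils.deepspeed) pack arbitrary Python state into a tensor
so that DeepSpeed checkpointing can serialize it. Their implementation is not shown in this diff; a
minimal sketch of what such helpers could look like (an assumption, not the actual Renate code) is:

.. code-block:: python

    import pickle

    import torch


    def convert_to_tensor(obj) -> torch.Tensor:
        # Pickle the object and expose the raw bytes as a uint8 tensor, so that
        # tensor-only checkpoint backends (e.g. DeepSpeed) can store it.
        return torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)


    def recover_object_from_tensor(tensor: torch.Tensor):
        # Inverse of convert_to_tensor: view the uint8 tensor as bytes and unpickle.
        return pickle.loads(tensor.cpu().numpy().tobytes())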
3 changes: 0 additions & 3 deletions src/renate/benchmark/models/mlp.py
@@ -18,7 +18,6 @@ class MultiLayerPerceptron(RenateBenchmarkingModule):
num_hidden_layers: Number of hidden layers.
hidden_size: Uniform hidden size or the list or tuple of hidden sizes for individual hidden
layers.
loss: Loss function to be used for training.
activation: Activation name, matching activation name in `torch.nn` to be used between the
hidden layers.
batch_normalization: Whether to use Batch Normalization after the activation. By default the
@@ -35,7 +34,6 @@
num_outputs: int,
num_hidden_layers: int,
hidden_size: Union[int, List[int], Tuple[int]],
loss: nn.Module = nn.CrossEntropyLoss(),
activation: str = "ReLU",
batch_normalization: bool = False,
prediction_strategy: Optional[PredictionStrategy] = None,
@@ -52,7 +50,6 @@
"activation": activation,
"batch_normalization": batch_normalization,
},
loss_fn=loss,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)
3 changes: 0 additions & 3 deletions src/renate/benchmark/models/resnet.py
@@ -29,7 +29,6 @@ class ResNet(RenateBenchmarkingModule):
norm_layer: What kind of normalization layer to use, following convolutions.
cifar_stem: Whether to use a stem for CIFAR-sized images.
gray_scale: Whether input images are gray-scale images, i.e. only 1 color channel.
loss: Loss function to be used for training.
prediction_strategy: Continual learning strategies may alter the prediction at train or test
time.
add_icarl_class_means: If ``True``, additional parameters used only by the
@@ -48,7 +47,6 @@
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
cifar_stem: bool = True,
gray_scale: bool = False,
loss: nn.Module = nn.CrossEntropyLoss(),
prediction_strategy: Optional[PredictionStrategy] = None,
add_icarl_class_means: bool = True,
) -> None:
@@ -76,7 +74,6 @@
"cifar_stem": cifar_stem,
"gray_scale": gray_scale,
},
loss_fn=loss,
prediction_strategy=prediction_strategy,
add_icarl_class_means=add_icarl_class_means,
)