Enable reloading of dataloaders every n epochs instead of every epoch #5043

Merged 87 commits on Jul 7, 2021
Commits
82dea77
edit arg to reload_dataloaders_every_n_epoch
sidhantls Dec 9, 2020
6ea2892
init reload_dataloaders_every_n_epoch
sidhantls Dec 9, 2020
9ba83c2
edit logic to reload dl
sidhantls Dec 9, 2020
65e837a
update arg to test datamodule
sidhantls Dec 9, 2020
171c404
update arg test dataloader
sidhantls Dec 9, 2020
e0f67a0
edit reload dl logic in eval loop
sidhantls Dec 9, 2020
e4f9275
fix var name in reset_train_val_dataloaders
sidhantls Dec 9, 2020
6f6e4af
fix error, use current_epoch attribute
sidhantls Dec 9, 2020
4a60ace
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
9c0cb47
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
3dd0d86
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
30ae4f8
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
3010d7c
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
fb1ce68
edit every_n_epoch to every_n_epochs
sidhantls Dec 11, 2020
8301ca8
assert reload_dataloaders_every_n_epochs positive
sidhantls Dec 11, 2020
adabdac
assert reload_dataloaders_every_n_epochs positive
sidhantls Dec 11, 2020
55b414b
add trainer property should reload dl
sidhantls Dec 11, 2020
dd3b733
update should reload dl in train loop
sidhantls Dec 11, 2020
290412e
condition on should reload dl in eval loop
sidhantls Dec 11, 2020
fdec509
pep8
sidhantls Dec 11, 2020
f9e986e
fix update should reload dl in train loop
sidhantls Dec 12, 2020
7d44cd7
add test case
sidhantls Dec 12, 2020
e4a01da
replace assertion with misconfig exception
sidhantls Dec 13, 2020
37c7cc4
remove unused variable
sidhantls Dec 13, 2020
0ec40ae
remove unnecessary checks
sidhantls Dec 13, 2020
8181684
replace to BoringModel
sidhantls Dec 13, 2020
cd8c909
remove unrequired comment
sidhantls Dec 13, 2020
47221df
deprecate _every_epoch
sidhantls Dec 13, 2020
bee1d03
add deprecated argument to trainer
sidhantls Dec 13, 2020
273b826
test case for deprecated arg
sidhantls Dec 13, 2020
51307d7
remove unrequired assertion in train loop
sidhantls Dec 13, 2020
036b7f6
modify misconfig exception for int
sidhantls Dec 13, 2020
6eed702
conv bool to int of depreciated _every_epoch
sidhantls Dec 13, 2020
e54256a
update description of deprecated param
sidhantls Dec 13, 2020
8efabab
update deprecation warning
sidhantls Dec 13, 2020
51de4c6
modify argument to int only
sidhantls Dec 13, 2020
3fa2d77
fix deprecated test function name
sidhantls Dec 14, 2020
1f998af
merge tests for reload dls
sidhantls Dec 14, 2020
ad86dc5
add propery should reload dl
sidhantls Dec 14, 2020
974286a
removed and added to trainer property
sidhantls Dec 14, 2020
9ccf1cd
use property in train loop
sidhantls Dec 14, 2020
67dac16
remove deprecated test
sidhantls Dec 14, 2020
d6278cd
add deprecated test to new file
sidhantls Dec 14, 2020
7d0f894
test case for exception
sidhantls Dec 14, 2020
0aaf3e2
update test datamodule every_n_epochs
sidhantls Dec 14, 2020
43ad55d
update trainer docs
sidhantls Dec 14, 2020
d870f86
update hooks with every_n_epochs
sidhantls Dec 14, 2020
c0e08ce
edit format if statement
sidhantls Dec 14, 2020
49032a9
Update CHANGELOG.md
sidhantls Dec 14, 2020
0891acf
Apply suggestions from code review
sidhantls Dec 14, 2020
dedc5e7
typo in exception
sidhantls Dec 14, 2020
f1d6489
pytest check only misconfig exception
sidhantls Dec 14, 2020
8f699f6
remove unnecessary code in test
sidhantls Dec 14, 2020
f4d665f
remove unnecessary code in deprec test
sidhantls Dec 14, 2020
64755be
added match in test
sidhantls Dec 14, 2020
8095a47
typo in comment
sidhantls Dec 14, 2020
3d3601b
revert to prev, keep only req in context manager
sidhantls Dec 15, 2020
56cd08a
Apply suggestions from code review
rohitgr7 Dec 18, 2020
8f61f7e
docs
rohitgr7 Dec 18, 2020
dea34c4
rebase
rohitgr7 Jan 31, 2021
a264bfa
Apply suggestions from code review
Borda May 11, 2021
1d19ed2
Merge branch 'master' into master
sidhantls May 14, 2021
03915af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
327f207
fix import: model_helpers instead of model_utils
sidhantls May 14, 2021
c4f251a
fix, add reload_dataloaders_every_n_epochs argument to data connector
sidhantls May 14, 2021
eecfd0f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
0a2b6ad
add required imports
sidhantls May 14, 2021
90eb410
move deprecated log
sidhantls May 14, 2021
029670b
add missing import rank_zero_warn
sidhantls May 14, 2021
f8bf40b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
113c393
update varname in should_reload_dl_epoch
sidhantls May 14, 2021
2aa1fad
Fix CHANGELOG. Update deprecation versions
carmocca May 17, 2021
c3b3b69
Minor change
carmocca May 17, 2021
2b516a3
change property name, mark protected
sidhantls May 18, 2021
459355b
update property name
sidhantls May 18, 2021
5a530cf
update property name
sidhantls May 18, 2021
5e22d0c
Merge branch 'master' into sid-sundrani/master
akihironitta Jun 24, 2021
041fca8
Remove deprecated *_loop.py files
akihironitta Jun 24, 2021
bd35756
Merge branch 'master' into sid-sundrani/master
akihironitta Jun 24, 2021
dea1011
Rename test func
akihironitta Jun 24, 2021
12903c4
Merge branch 'master' into sid-sundrani/master
carmocca Jun 24, 2021
11cee0e
Update CHANGELOG.md
awaelchli Jul 6, 2021
8fc2fa5
use rank_zero_deprecation
awaelchli Jul 6, 2021
2cacfbb
update deprecation message in trainer api docs
awaelchli Jul 6, 2021
0a4f87f
test deprecation with real arg name in message
awaelchli Jul 6, 2021
fe1ca87
fix typo in trainer docs
awaelchli Jul 6, 2021
e173cee
Merge branch 'master' into master
awaelchli Jul 6, 2021
7 changes: 5 additions & 2 deletions CHANGELOG.md
@@ -212,6 +212,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- MLflowLogger now uses the env variable `MLFLOW_TRACKING_URI` as default tracking uri ([#7457](https://github.com/PyTorchLightning/pytorch-lightning/pull/7457))

- Changed `Trainer` arg and functionality from `reload_dataloaders_every_epoch` to `reload_dataloaders_every_n_epochs` ([#5043](https://github.com/PyTorchLightning/pytorch-lightning/pull/5043))

- Changed `WandbLogger(log_model={True/'all'})` to log models as artifacts ([#6231](https://github.com/PyTorchLightning/pytorch-lightning/pull/6231))

@@ -288,6 +289,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `mode` parameter in `ModelSummary` in favor of `max_depth` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))


- Deprecated `reload_dataloaders_every_epoch` argument of `Trainer` in favor of `reload_dataloaders_every_n_epochs` ([#5043](https://github.com/PyTorchLightning/pytorch-lightning/pull/5043))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
@@ -708,6 +712,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#6323](https://github.com/PyTorchLightning/pytorch-lightning/pull/6323),
[#6211](https://github.com/PyTorchLightning/pytorch-lightning/pull/6211))


## [1.2.9] - 2021-04-20

### Fixed
@@ -752,8 +757,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))
- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))


- Fixed bug where `predict` could not be used when `progress_bar_refresh_rate=0` ([#6884](https://github.com/PyTorchLightning/pytorch-lightning/pull/6884))


13 changes: 7 additions & 6 deletions docs/source/common/trainer.rst
@@ -1297,8 +1297,8 @@ Note:
Lightning will set it to 20 in these environments if the user does not provide a value.
- This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.

reload_dataloaders_every_epoch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
reload_dataloaders_every_n_epochs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. raw:: html

@@ -1308,19 +1308,20 @@ reload_dataloaders_every_epoch

|

Set to True to reload dataloaders every epoch.
Set to a positive integer to reload dataloaders every n epochs.

.. code-block:: python

# if False (default)
# if 0 (default)
train_loader = model.train_dataloader()
for epoch in epochs:
for batch in train_loader:
...

# if True
# if a positive integer
for epoch in epochs:
train_loader = model.train_dataloader()
if not epoch % reload_dataloaders_every_n_epochs:
train_loader = model.train_dataloader()
for batch in train_loader:

.. _replace-sampler-ddp:
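For context (not part of the diff above), a minimal self-contained sketch of the behaviour the two pseudocode blocks describe, counting how often `train_dataloader()` would run under each setting; the function name and values are illustrative only:

def train_dataloader_calls(epochs: int, reload_every_n_epochs: int) -> int:
    """Count how often train_dataloader() runs under the documented patterns."""
    calls = 0
    if reload_every_n_epochs == 0:
        calls += 1  # default: the loader is built once, outside the epoch loop
    else:
        for epoch in range(epochs):
            if epoch % reload_every_n_epochs == 0:
                calls += 1  # rebuilt at every epoch index divisible by n
    return calls

assert train_dataloader_calls(epochs=4, reload_every_n_epochs=0) == 1
assert train_dataloader_calls(epochs=4, reload_every_n_epochs=1) == 4
assert train_dataloader_calls(epochs=4, reload_every_n_epochs=2) == 2  # epochs 0 and 2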
15 changes: 9 additions & 6 deletions pytorch_lightning/core/hooks.py
@@ -435,8 +435,9 @@ def train_dataloader(self) -> TRAIN_DATALOADERS:
A collection of :class:`torch.utils.data.DataLoader` specifying training samples.
In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

For data processing use the following pattern:

@@ -505,8 +506,9 @@ def test_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for testing.

The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

For data processing use the following pattern:

@@ -565,8 +567,9 @@ def val_dataloader(self) -> EVAL_DATALOADERS:
r"""
Implement one or multiple PyTorch DataLoaders for validation.

The dataloader you return will not be called every epoch unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
The dataloader you return will not be reloaded unless you set
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_n_epochs` to
a positive integer.

It's recommended that all data downloads and preparation happen in :meth:`prepare_data`.

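As an illustration of the docstring changes above (a sketch under assumptions, not code from this PR): a hypothetical datamodule whose `train_dataloader` builds fresh random data on every call, so reloading every n epochs actually changes the data the model sees.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class RegeneratingDataModule(pl.LightningDataModule):
    """Hypothetical example: each call to train_dataloader builds a new dataset."""

    def train_dataloader(self) -> DataLoader:
        # Re-invoked by the trainer only when reload_dataloaders_every_n_epochs > 0.
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8, shuffle=True)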
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -180,7 +180,7 @@ def reload_evaluation_dataloaders(self) -> None:
model = self.trainer.lightning_module
if self.trainer.testing:
self.trainer.reset_test_dataloader(model)
elif self.trainer.val_dataloaders is None or self.trainer.reload_dataloaders_every_epoch:
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_dl_epoch:
self.trainer.reset_val_dataloader(model)

def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -191,7 +191,7 @@ def on_advance_start(self) -> None:
model = self.trainer.lightning_module

# reset train dataloader
if self.current_epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
if self.current_epoch != 0 and self.trainer._should_reload_dl_epoch:
self.trainer.reset_train_dataloader(model)

# TODO: specify the possible exception
23 changes: 21 additions & 2 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -16,6 +16,7 @@

import pytorch_lightning as pl
from pytorch_lightning.trainer.supporters import prefetch_iterator
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -28,7 +29,11 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
self.multiple_trainloader_mode = multiple_trainloader_mode

def on_trainer_init(
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
self,
check_val_every_n_epoch: int,
reload_dataloaders_every_n_epochs: int,
reload_dataloaders_every_epoch: bool,
prepare_data_per_node: bool,
) -> None:
self.trainer.datamodule = None
self.trainer.prepare_data_per_node = prepare_data_per_node
@@ -39,7 +44,21 @@ def on_trainer_init(
)

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

if reload_dataloaders_every_epoch:
reload_dataloaders_every_n_epochs = int(reload_dataloaders_every_epoch)
rank_zero_deprecation(
"`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed in v1.6."
" Please use `reload_dataloaders_every_n_epochs` in Trainer."
)

if not isinstance(reload_dataloaders_every_n_epochs, int) or (reload_dataloaders_every_n_epochs < 0):
raise MisconfigurationException(
"`reload_dataloaders_every_n_epochs` should be an int >= 0,"
f" got {reload_dataloaders_every_n_epochs}."
)

self.trainer.reload_dataloaders_every_n_epochs = reload_dataloaders_every_n_epochs
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
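A standalone sketch of the argument handling above (the function name and ValueError are illustrative; the connector itself raises MisconfigurationException): the deprecated boolean maps onto the new integer argument, and anything other than a non-negative int is rejected.

def resolve_reload_arg(every_n_epochs=0, every_epoch=False):
    # Illustrative only; mirrors DataConnector.on_trainer_init above.
    if every_epoch:
        every_n_epochs = int(every_epoch)  # deprecated True maps to reloading every epoch
    if not isinstance(every_n_epochs, int) or every_n_epochs < 0:
        raise ValueError(
            f"`reload_dataloaders_every_n_epochs` should be an int >= 0, got {every_n_epochs}."
        )
    return every_n_epochs

assert resolve_reload_arg() == 0                  # default: no reloading
assert resolve_reload_arg(every_n_epochs=2) == 2  # reload every 2 epochs
assert resolve_reload_arg(every_epoch=True) == 1  # deprecated flag behaves like n=1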
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -59,6 +59,7 @@ class TrainerProperties(ABC):
accelerator_connector: AcceleratorConnector
callbacks: List[Callback]
checkpoint_connector: CheckpointConnector
reload_dataloaders_every_n_epochs: int
limit_val_batches: int
logger: LightningLoggerBase
logger_connector: LoggerConnector
@@ -293,6 +294,12 @@ def progress_bar_dict(self) -> dict:
)
return {**standard_metrics, **pbar_metrics}

@property
def _should_reload_dl_epoch(self) -> bool:
""" Check if dataloader should be reloaded in the current epoch. """
n_epochs = self.reload_dataloaders_every_n_epochs
return n_epochs and (not self.current_epoch % n_epochs)

@property
def disable_validation(self) -> bool:
""" Check if validation is disabled during training. """
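Putting the new property together with the fit-loop guard shown earlier, a plain-Python sketch (no Trainer required, names illustrative) of which epochs actually trigger a train-dataloader reload; `current_epoch` is assumed to count from 0:

def train_reload_epochs(max_epochs: int, every_n_epochs: int) -> list:
    """Epochs at which the fit loop would call reset_train_dataloader."""
    reloads = []
    for epoch in range(max_epochs):
        # _should_reload_dl_epoch: n > 0 and the epoch index is divisible by n
        should_reload = bool(every_n_epochs) and epoch % every_n_epochs == 0
        # fit_loop.on_advance_start additionally skips epoch 0 (loaders already built)
        if epoch != 0 and should_reload:
            reloads.append(epoch)
    return reloads

assert train_reload_epochs(6, 1) == [1, 2, 3, 4, 5]
assert train_reload_epochs(6, 2) == [2, 4]
assert train_reload_epochs(6, 0) == []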
11 changes: 10 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -144,6 +144,7 @@
profiler: Optional[Union[BaseProfiler, str]] = None,
benchmark: bool = False,
deterministic: bool = False,
reload_dataloaders_every_n_epochs: int = 0,
reload_dataloaders_every_epoch: bool = False,
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
@@ -272,8 +273,15 @@
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders.

reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs.
Default: 0

reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

.. deprecated:: v1.4
``reload_dataloaders_every_epoch`` has been deprecated in v1.4 and will be removed in v1.6.
Please use ``reload_dataloaders_every_n_epochs``.

replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
@@ -382,7 +390,8 @@ def __init__(

# init data flags
self.data_connector.on_trainer_init(
check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node
check_val_every_n_epoch, reload_dataloaders_every_n_epochs, reload_dataloaders_every_epoch,
prepare_data_per_node
)

# init training tricks
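Migration sketch for users of the deprecated flag (assuming a Lightning version that includes this PR): the old boolean is equivalent to n=1.

from pytorch_lightning import Trainer

# Deprecated since v1.4, removed in v1.6:
#   trainer = Trainer(reload_dataloaders_every_epoch=True)

# Equivalent with the new argument; any positive n reloads every n epochs:
trainer = Trainer(reload_dataloaders_every_n_epochs=1)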
14 changes: 8 additions & 6 deletions tests/core/test_datamodules.py
@@ -455,9 +455,11 @@ def transfer_batch_to_device(self, batch, device, dataloader_idx):
assert torch.allclose(batch_gpu.targets.cpu(), torch.ones(5, 1, dtype=torch.long) * 2)


def test_dm_reload_dataloaders_every_epoch(tmpdir):
"""Test datamodule, where trainer argument
reload_dataloaders_every_epoch is set to True/False"""
def test_dm_reload_dataloaders_every_n_epochs(tmpdir):
"""
Test datamodule, where trainer argument
reload_dataloaders_every_n_epochs is set to a non negative integer
"""

class CustomBoringDataModule(BoringDataModule):

@@ -482,9 +484,9 @@ def train_dataloader(self):

trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
limit_train_batches=0.01,
reload_dataloaders_every_epoch=True,
max_epochs=3,
limit_train_batches=2,
reload_dataloaders_every_n_epochs=2,
)
trainer.fit(model, dm)

22 changes: 22 additions & 0 deletions tests/deprecated_api/test_remove_1-6.py
@@ -88,6 +88,28 @@ def test_v1_6_0_ddp_spawn_sync_batchnorm():
DDPSpawnPlugin(sync_batchnorm=False)


def test_v1_6_0_reload_dataloaders_every_epoch(tmpdir):

model = BoringModel()

with pytest.deprecated_call(match='`reload_dataloaders_every_epoch` is deprecated in v1.4 and will be removed'):
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_epoch=True,
max_epochs=3,
)
trainer.fit(model)
trainer.test()

# verify the sequence
calls = trainer.dev_debugger.dataloader_sequence_calls
expected_sequence = ['val_dataloader'] + ['train_dataloader', 'val_dataloader'] * 3 + ['test_dataloader']
for call, expected in zip(calls, expected_sequence):
assert call['name'] == expected


def test_v1_6_0_tbptt_reduce_fx(tmpdir):

class TestModel(BoringModel):
44 changes: 23 additions & 21 deletions tests/trainer/test_dataloaders.py
@@ -1304,7 +1304,7 @@ def test_dataloaders_load_only_once_val_interval(tmpdir):
limit_train_batches=10,
limit_val_batches=10,
val_check_interval=0.3,
reload_dataloaders_every_epoch=True,
reload_dataloaders_every_n_epochs=True,
max_epochs=3,
)
trainer.fit(model)
@@ -1368,44 +1368,46 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
assert call['name'] == expected


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_every_epoch(tmpdir):
@pytest.mark.parametrize("n", [1, 2])
def test_dataloaders_load_every_n_epochs(tmpdir, n):

model = EvalModelTemplate()
model = BoringModel()

# logger file to get meta
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=0.3,
limit_val_batches=0.3,
reload_dataloaders_every_epoch=True,
reload_dataloaders_every_n_epochs=n,
max_epochs=3,
)
trainer.fit(model)
assert trainer.state.finished, f"Training failed with {trainer.state}"

trainer.test()

assert len(trainer.dev_debugger.val_dataloader_calls) == 4
assert len(trainer.dev_debugger.train_dataloader_calls) == 3
assert len(trainer.dev_debugger.test_dataloader_calls) == 1

# verify the sequence
calls = trainer.dev_debugger.dataloader_sequence_calls
expected_sequence = [
'val_dataloader',
'train_dataloader',
'val_dataloader',
'train_dataloader',
'val_dataloader',
'train_dataloader',
'val_dataloader',
'test_dataloader',
]
expected_sequence = ['val_dataloader']
if n == 1:
expected_sequence += ['train_dataloader', 'val_dataloader'] * 3
elif n == 2:
expected_sequence += ['train_dataloader', 'val_dataloader'] * 2
expected_sequence += ['test_dataloader']

for call, expected in zip(calls, expected_sequence):
assert call['name'] == expected


@pytest.mark.parametrize("n", ['test', -1])
def test_dataloaders_load_every_n_epochs_exception(tmpdir, n):

with pytest.raises(MisconfigurationException, match='should be an int >'):
Trainer(
default_root_dir=tmpdir,
reload_dataloaders_every_n_epochs=n,
)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_load_every_epoch_no_sanity_check(tmpdir):

@@ -1426,7 +1428,7 @@ def validation_step(self, batch, batch_idx):
limit_train_batches=0.3,
limit_val_batches=0.3,
num_sanity_val_steps=0,
reload_dataloaders_every_epoch=True,
reload_dataloaders_every_n_epochs=True,
max_epochs=3,
callbacks=[checkpoint_callback],
)