[doc] Update Dict Train Loader doc. (#6579)
* update doc

* update example
tchaton authored Mar 18, 2021
1 parent 9e35f97 commit 8853a36
Showing 1 changed file with 42 additions and 8 deletions.
50 changes: 42 additions & 8 deletions docs/source/advanced/multiple_loaders.rst
@@ -9,7 +9,7 @@ Multiple Datasets
Lightning supports multiple dataloaders in a few ways.

1. Create a dataloader that iterates multiple datasets under the hood (see the sketch after this list).
2. In the training loop you can pass multiple loaders as a dict or list/tuple and lightning
will automatically combine the batches from different loaders.
3. In the validation and test loop you also have the option to return multiple dataloaders
which lightning will call sequentially.
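
As a rough illustration of option 1, here is a minimal sketch assuming plain ``torch.utils.data.ConcatDataset`` (not a Lightning-specific helper): the datasets are merged and handed to a single dataloader.

.. testcode::

    import torch

    dataset_a = torch.utils.data.TensorDataset(torch.randn(32, 10))
    dataset_b = torch.utils.data.TensorDataset(torch.randn(48, 10))

    # one dataloader that iterates both datasets under the hood
    combined = torch.utils.data.ConcatDataset([dataset_a, dataset_b])
    loader = torch.utils.data.DataLoader(combined, batch_size=4)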
@@ -75,21 +75,38 @@ For more details please have a look at :paramref:`~pytorch_lightning.trainer.tra

loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)

# pass loaders as a dict. This will create batches like this:
# {'a': batch from loader_a, 'b': batch from loader_b}
loaders = {'a': loader_a,
           'b': loader_b}

# OR:
# pass loaders as sequence. This will create batches like this:
# [batch from loader_a, batch from loader_b]
loaders = [loader_a, loader_b]

return loaders
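
When the loaders are returned as a sequence, the combined batch arrives in ``training_step`` as a list in the same order as the loaders; a minimal sketch under that assumption:

.. testcode::

    class LitModel(LightningModule):

        def train_dataloader(self):
            loader_a = torch.utils.data.DataLoader(range(6), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(15), batch_size=5)
            # pass loaders as a sequence
            return [loader_a, loader_b]

        def training_step(self, batch, batch_idx):
            # batch is a list: [batch from loader_a, batch from loader_b]
            batch_a, batch_b = batch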

Furthermore, Lightning also supports returning nested lists and dicts (or a combination of the two).

.. testcode::

class LitModel(LightningModule):

    def train_dataloader(self):

        loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
        loader_b = torch.utils.data.DataLoader(range(16), batch_size=2)

        return {'a': loader_a, 'b': loader_b}

    def training_step(self, batch, batch_idx):
        # access a dictionary with a batch from each dataloader
        batch_a = batch["a"]
        batch_b = batch["b"]


.. testcode::

@@ -103,12 +120,29 @@ be returned
loader_c = torch.utils.data.DataLoader(range(64), batch_size=4)

# pass loaders as a nested dict. This will create batches like this:
# {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
#  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
loaders = {
    'loaders_a_b': {
        'a': loader_a,
        'b': loader_b
    },
    'loaders_c_d': {
        'c': loader_c,
        'd': loader_d
    }
}
return loaders

def training_step(self, batch, batch_idx):
# access the data
batch_a_b = batch["loaders_a_b"]
batch_c_d = batch["loaders_c_d"]

batch_a = batch_a_b["a"]
batch_b = batch_a_b["b"]

batch_c = batch_c_d["c"]
batch_d = batch_c_d["d"]
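
A short usage sketch for launching training with either variant, assuming the standard ``Trainer`` API (purely illustrative; the ``training_step`` above would still need to compute and return a loss):

.. testcode::

    from pytorch_lightning import Trainer

    model = LitModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)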

----------

Test/Val dataloaders
