Dataloader with dicts #37
Conversation
Hi @AlexGrig, thanks for the contribution. For this case, I would consider replacing the following line with a method that serves as an overridable interface for users (to something like this):

```python
class LRFinder(object):
    def range_test(...):
        # ...
        # Create an iterator to get data batch by batch
        iter_wrapper = self.make_iterator(train_loader)
        # ...

    def make_iterator(self, dataloader):
        # Override this method to use a custom `DataLoaderIterWrapper`
        return DataLoaderIterWrapper(dataloader)
```

With this approach, users can handle their complex datasets and make them conform to the format we use:

```python
class CustomDataloaderIterWrapper(object):
    def __next__(self):
        # ...
        return inputs, labels


class CustomLRFinder(LRFinder):
    def make_iterator(self, dataloader):
        return CustomDataloaderIterWrapper(dataloader)
```

This implementation should be more convenient and require fewer changes to the original codebase. And I think it would be better to wait for @davidtvs's response before making further changes.
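A runnable, torch-free sketch of this override pattern may help. A plain list of dict batches stands in for a real `torch.utils.data.DataLoader`, and names like `DictIterWrapper`, `first_batch`, and the dict keys are illustrative, not from the actual codebase:

```python
class DataLoaderIterWrapper:
    """Default wrapper: assumes each batch is an (inputs, labels) tuple."""
    def __init__(self, dataloader):
        self._iterator = iter(dataloader)

    def __next__(self):
        inputs, labels, *_ = next(self._iterator)
        return inputs, labels


class LRFinder:
    def make_iterator(self, dataloader):
        # Override this method to plug in a custom wrapper
        return DataLoaderIterWrapper(dataloader)

    def first_batch(self, dataloader):
        # Stand-in for the part of range_test() that consumes batches
        return next(self.make_iterator(dataloader))


class DictIterWrapper(DataLoaderIterWrapper):
    """Wrapper for dataloaders that yield dict batches (keys are assumed)."""
    def __next__(self):
        batch = next(self._iterator)
        return batch["image"], batch["label"]


class CustomLRFinder(LRFinder):
    def make_iterator(self, dataloader):
        return DictIterWrapper(dataloader)


dict_batches = [{"image": [1, 2], "label": [0, 1]}]
inputs, labels = CustomLRFinder().first_batch(dict_batches)
print(inputs, labels)  # [1, 2] [0, 1]
```

The base `LRFinder` never needs to know about dict batches; only `make_iterator` changes.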
Thanks for the PR @AlexGrig. I would actually suggest somewhat of a mix of the ideas from @AlexGrig and @NaleRaphael. My suggestion is to create a base class from the current wrapper:

```python
class DataLoaderIter(object):
    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def inputs_labels_from_batch(self, batch_data):
        inputs, labels, *_ = batch_data
        return inputs, labels

    def __next__(self):
        batch = next(self._iterator)
        return self.inputs_labels_from_batch(batch)
```

The current wrapper for the training dataloader then becomes a subclass:

```python
class TrainDataLoaderIter(DataLoaderIter):
    def __next__(self):
        try:
            batch = next(self._iterator)
            inputs, labels = self.inputs_labels_from_batch(batch)
        except StopIteration:
            if not self.auto_reset:
                raise
            self._iterator = iter(self.data_loader)
            batch = next(self._iterator)
            inputs, labels = self.inputs_labels_from_batch(batch)
        return inputs, labels
```

The user can then create their own custom dataloader wrapper in a similar fashion:

```python
class CustomTrainDataLoaderIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        # Manipulate the batch_data to get inputs and labels
        # ...
        return inputs, labels
```

I've named these classes with the word `Train` because a corresponding class is needed for the validation dataloader:

```python
class ValDataLoaderIter(DataLoaderIter):
    pass
```

which can also be overridden like this:

```python
class CustomValDataLoaderIter(ValDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        # Manipulate the batch_data to get inputs and labels
        # ...
        return inputs, labels
```

The only thing left to do is integrate this into `LRFinder`:

```python
class LRFinder(object):
    def range_test(...):
        # ...
        if isinstance(train_loader, DataLoader):
            train_iter = TrainDataLoaderIter(train_loader)
        elif isinstance(train_loader, TrainDataLoaderIter):
            train_iter = train_loader
        else:
            raise ValueError("...")

        if isinstance(val_loader, DataLoader):
            val_iter = ValDataLoaderIter(val_loader)
        elif isinstance(val_loader, ValDataLoaderIter):
            val_iter = val_loader
        else:
            raise ValueError("...")
```

Further changes would be needed elsewhere in the code as well. What do you guys think?
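For reference, here is a torch-free, runnable sketch of that class hierarchy; a plain list of batches stands in for `torch.utils.data.DataLoader`, so the wrap-around behavior can be observed directly (all behavior shown is under that stand-in assumption):

```python
class DataLoaderIter:
    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def inputs_labels_from_batch(self, batch_data):
        # Default: assume (inputs, labels, ...) tuple batches
        inputs, labels, *_ = batch_data
        return inputs, labels


class TrainDataLoaderIter(DataLoaderIter):
    def __next__(self):
        try:
            batch = next(self._iterator)
        except StopIteration:
            if not self.auto_reset:
                raise
            # Restart the underlying loader so training can run for
            # more iterations than the dataset has batches
            self._iterator = iter(self.data_loader)
            batch = next(self._iterator)
        return self.inputs_labels_from_batch(batch)


class ValDataLoaderIter(DataLoaderIter):
    def __next__(self):
        # No auto-reset: a validation pass visits each batch once
        return self.inputs_labels_from_batch(next(self._iterator))


batches = [("x0", "y0"), ("x1", "y1")]
train_iter = TrainDataLoaderIter(batches)
# Pulling 3 batches from a 2-batch loader wraps around
pulled = [next(train_iter) for _ in range(3)]
print(pulled)  # [('x0', 'y0'), ('x1', 'y1'), ('x0', 'y0')]
```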
@davidtvs Great idea, I second this approach.
If I understand correctly, the idea of changing these names is to make them represent their own types unambiguously. In my opinion, it's fine to keep them unchanged, since users usually pass a plain `DataLoader`; the accepted types can instead be documented in the docstring of `range_test`:

```python
def range_test(...):
    """
    Arguments:
        train_loader (torch.utils.data.DataLoader or DataLoaderIter): ...
        val_loader (torch.utils.data.DataLoader or DataLoaderIter, optional): ...
    """
```
Hello, thanks for the comments, they sound quite reasonable. I'll implement these changes in the next few days. If something comes up, I'll bring it up for discussion here. I also agree with @NaleRaphael that the interface to `range_test` should feel natural to PyTorch users: they should be able to just use their `DataLoader`s. Ok, now that I've looked closer, I have a question for @davidtvs:
If you notice, `TrainDataLoaderIter` resets its underlying iterator while `ValDataLoaderIter` does not. The difference between them is because the user specifies the number of iterations for the training loop, which can exceed the number of batches in the training dataloader.
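That behavioral difference can be illustrated with plain iterators, no torch needed: the training side behaves like `itertools.cycle`, while the validation side is exhausted after a single pass (variable names here are illustrative).

```python
import itertools

batches = [("x0", "y0"), ("x1", "y1")]

# Training: the requested number of iterations may exceed len(batches),
# so the iterator wraps around
train_iter = itertools.cycle(batches)
num_iter = 5
train_seen = [next(train_iter) for _ in range(num_iter)]
print(len(train_seen))  # 5

# Validation: evaluate each batch exactly once per validation pass
val_iter = iter(batches)
val_seen = list(val_iter)
print(len(val_seen))  # 2
```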
Ok, thanks @davidtvs, I see the meaning now. I'll implement it according to this approach.
Hello, @davidtvs and @NaleRaphael
/flake8-lint |
Lintly has detected code quality issues in this pull request.
```diff
@@ -1 +1,3 @@
 from torch_lr_finder.lr_finder import LRFinder
+from torch_lr_finder.lr_finder import TrainDataLoaderIter
```
F401: 'torch_lr_finder.lr_finder.TrainDataLoaderIter' imported but unused
```diff
@@ -1 +1,3 @@
 from torch_lr_finder.lr_finder import LRFinder
+from torch_lr_finder.lr_finder import TrainDataLoaderIter
+from torch_lr_finder.lr_finder import ValDataLoaderIter
```
F401: 'torch_lr_finder.lr_finder.ValDataLoaderIter' imported but unused
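One common way to resolve these F401 findings for intentional re-exports is to declare the package's public API in `__all__`. This is only a sketch of such a fix, not necessarily what the PR ended up doing:

```python
# torch_lr_finder/__init__.py (sketch)
from torch_lr_finder.lr_finder import LRFinder
from torch_lr_finder.lr_finder import TrainDataLoaderIter
from torch_lr_finder.lr_finder import ValDataLoaderIter

# Declaring the public API tells linters these imports are intentional
__all__ = ["LRFinder", "TrainDataLoaderIter", "ValDataLoaderIter"]
```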
Merged. Thanks for helping!
If a `DataLoader` you are working with does not yield `(Xs, Ys)` tuples but instead yields dicts of values, this PR allows you to handle that. It is possible to redefine the `DataLoader` itself, but that is not always convenient. I have added an extra parameter to `range_test` that takes a class. This class is supposed to inherit from `DataLoaderIterWrapper`; in the new class, redefine the method `_batch_make_inputs_labels` so that it returns the tuple `(Xs, Ys)`. There is also an example of usage.
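A hedged sketch of the usage this description implies. A plain list of dict batches stands in for a real `torch.utils.data.DataLoader`, and the subclass name and dict keys are assumptions for illustration:

```python
class DataLoaderIterWrapper:
    """Base wrapper: assumes (inputs, labels, ...) tuple batches."""
    def __init__(self, dataloader):
        self._iterator = iter(dataloader)

    def _batch_make_inputs_labels(self, batch):
        inputs, labels, *_ = batch
        return inputs, labels

    def __next__(self):
        return self._batch_make_inputs_labels(next(self._iterator))


class DictBatchIterWrapper(DataLoaderIterWrapper):
    """Illustrative subclass for dataloaders that yield dict batches."""
    def _batch_make_inputs_labels(self, batch):
        # Pull Xs and Ys out of the dict; the key names are assumptions
        return batch["inputs"], batch["targets"]


batches = [{"inputs": [1.0, 2.0], "targets": [0, 1]}]
wrapper = DictBatchIterWrapper(batches)
first = next(wrapper)
print(first)  # ([1.0, 2.0], [0, 1])
```

The wrapper class (rather than a wrapper instance) would then be passed to `range_test`, which constructs it around the user's dataloader.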