Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dataloader with dicts #37

Merged
merged 8 commits into from
May 29, 2020
Merged

Conversation

AlexGrig
Copy link
Contributor

@AlexGrig AlexGrig commented May 8, 2020

If a DataLoader you are working with does not provide outputs of (Xs,Ys) but instead output dict with values then this PR allows to handle this. It is possible to redefine the DataLoader itself, but it is not always convenient. I have added an extra parameter to range_test with the class name. This class is supposed to be inherited from DataLoaderIterWrapper. In the new calss redefine method _batch_make_inputs_labels of the new class so that it return tuple (Xs, Ys). There is also an example of usage.

@NaleRaphael
Copy link
Contributor

Hi @AlexGrig, thanks for the contribution.

For this case, I would consider replacing the following line with a method as an overridable interface for users:

iter_wrapper = DataLoaderIterWrapper(train_loader)

(to something like this)

class LRFinder(object):
    def range_test(...):
        # ...

        # Create an iterator to get data batch by batch
        iter_wrapper = self.make_iterator(train_loader)

        # ...

    def make_iterator(self, dataloader):
        # Override this method to use a custom `DataLoaderIterWrapper`
        return DataLoaderIterWrapper(train_loader)

With this approach, users can handle their complex dataset and make it able to conform with the format we used. This implementation should be more convenient and make less changes to the original codebase.

class CustomDataloaderIterWrapper(object):
    def __next__(self):
        # ...
        return inputs, labels


class CustomLRFinder(LRFinder):
    def make_iterator(self, dataloader):
        return CustomDataloaderIterWrapper(dataloader)

And I think it would be better to wait for @davidtvs's response before making further changes.

@davidtvs
Copy link
Owner

davidtvs commented May 9, 2020

Thanks for the PR @AlexGrig.

I would actually suggest somewhat of a mix of the ideas from @AlexGrig and @NaleRaphael. My suggestion is to create a base class from the current DataLoaderIterWrapper:

class DataLoaderIter(object):
    def __init__(self, data_loader, auto_reset=True):
        self.data_loader = data_loader
        self.auto_reset = auto_reset
        self._iterator = iter(data_loader)

    def inputs_labels_from_batch(self, batch_data):
        inputs, labels, *_ = batch_data

        return inputs, labels

    def __next__(self):
        return self.inputs_labels_from_batch(batch)

The current DataLoaderIterWrapper becomes a child class that overrides the __next__ method:

class TrainDataLoaderIter(DataLoaderIter):
    def __next__(self):
        try:
            batch = next(self._iterator)
            inputs, labels = self.inputs_labels_from_batch(batch)
        except StopIteration:
            if not self.auto_reset:
                raise
            self._iterator = iter(self.data_loader)
            batch = next(self._iterator)
            inputs, labels = self.inputs_labels_from_batch(batch)

        return inputs, labels

The user can then create its own custom dataloader wrapper in a similar fashion:

class CustomTrainDataLoaderIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        # Manipulate the batch_data and to get inputs and labels
        # ...

        return inputs, labels

I've named these classes with the word Train because we actually have the same problem for the validation dataloader. It currently assumes that it returns a pair of inputs, labels but that might not be true for all users. So I suggest creating a ValDataLoaderIter

class ValDataLoaderIter(DataLoaderIter):
    pass

which can also be overridden like this:

class CustomValDataLoaderIter(ValDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        # Manipulate the batch_data to get inputs and labels
        # ...
        return inputs, labels

The only thing left to do is integrate this into the range_test method. For this, and to avoid API changes, I propose that the parameters train_loader and val_loader can either be instances of the DataLoader class or of the TrainDataLoaderIter and ValDataLoaderIter classes, respectively, and we would handle it like so:

class LRFinder(object):
    def range_test(...):
        # ...
        if isinstance(train_loader, DataLoader):
            train_iter = TrainDataLoaderIter(train_loader)
        else if isinstance(train_loader, TrainDataLoaderIter):
            train_iter = train_loader
        else:
            raise ValueError("...")
        
        if isinstance(val_loader, DataLoader):
            val_iter = ValDataLoaderIter(val_loader)
        else if isinstance(val_loader, ValDataLoaderIter):
            val_iter = val_loader
        else:
            raise ValueError("...")

Further changes would be needed within range_test, _train_batch, and _validate. I would like a better name to replace train_loader and val_loader but currently can't think of one.

What do you guys think?

@NaleRaphael
Copy link
Contributor

@davidtvs Great idea, I second this approach.

I would like a better name to replace train_loader and val_loader but currently can't think of one.

If I understand correctly, the idea of changing these names is to make them represent their own types unambiguously, i.e. they are DataLoaderIter instead of DataLoader.

In my opinion, it's fine to keep them unchanged since users usually pass DataLoader to it. And we can revise those parts related to train_loader and val_loader in the docstring to make users know that DataLoaderIter is also an acceptable type of argument. Add some examples of using DataLoaderIter for this case in README.md might be helpful, too.

def range_test(...):
    """
    Arguments:
        train_loader (torch.utils.data.DataLoader or DataLoaderIter): ...
        val_loader (torch.utils.data.DataLoader or DataLoaderIter, optional): ...
    """

@AlexGrig
Copy link
Contributor Author

AlexGrig commented May 12, 2020

Hello,

Thanks for the comments, they sound quite reasonable. I ll implement these changes in the a few days. If something comes up, I ll bring it for discussion here.

I also agree with @NaleRaphael that the interface to range_test should be natural to Pytorch users and they should be able just to use their DataLoaders.

Ok, now I looked closer. I have a question to @davidtvs:
I do not understand, why you propose two classes for training and validation: TrainDataLoaderIter and ValDataLoaderIter? It seems to me that there can be only one, and if a user needs a difference, he could inherit and define two different classes in his code.

@davidtvs
Copy link
Owner

If you notice TrainLoaderIter overrides the __next__ method while ValLoaderIter doesn't.

The difference between them is because the user specifies the number of iterations (num_iter in range_test) he wants for the range test (training) but the validation is always for a full epoch.
So the __next__ method in TrainLoaderIter needs to handle cases where the user specifies num_iter larger than the number of iterations in a single epoch of training. On the other hand, ValLoaderIter doesn't need to handle that because it always does one full epoch.

@AlexGrig
Copy link
Contributor Author

Ok, thanks @davidtvs I see now the meaning, I ll implement according to this approach

@AlexGrig
Copy link
Contributor Author

Hello, @davidtvs and @NaleRaphael
I have finished the planned modifications, if you have more comments feel free to express.

torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
@davidtvs
Copy link
Owner

/flake8-lint

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lintly has detected code quality issues in this pull request.

@@ -1 +1,3 @@
from torch_lr_finder.lr_finder import LRFinder
from torch_lr_finder.lr_finder import TrainDataLoaderIter

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

F401: 'torch_lr_finder.lr_finder.TrainDataLoaderIter' imported but unused

@@ -1 +1,3 @@
from torch_lr_finder.lr_finder import LRFinder
from torch_lr_finder.lr_finder import TrainDataLoaderIter
from torch_lr_finder.lr_finder import ValDataLoaderIter

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

F401: 'torch_lr_finder.lr_finder.ValDataLoaderIter' imported but unused

torch_lr_finder/__init__.py Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
torch_lr_finder/lr_finder.py Outdated Show resolved Hide resolved
@AlexGrig AlexGrig requested a review from davidtvs May 28, 2020 13:01
@davidtvs davidtvs merged commit 52c189a into davidtvs:master May 29, 2020
@davidtvs
Copy link
Owner

Merged. Thanks for helping!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants