
Is hparams really a good practice? #1735

Closed
elkotito opened this issue May 5, 2020 · 18 comments
Labels: discussion (In a discussion stage), help wanted (Open to be worked on), question (Further information is requested)

@elkotito
Contributor

elkotito commented May 5, 2020

❓ Questions and Help

I am a bit confused about good practices in PyTorch Lightning, particularly with hparams in mind. I will share some of my thoughts on this topic. The docs say:

# YES
model = LitModel(hparams)
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)

# NO
# model = LitModel(learning_rate=hparams.learning_rate, ...)
# trainer = Trainer(gpus=hparams.gpus, ...)
  1. Does it allow parametrizing the EarlyStopping callback, e.g. patience? The only way I can think of uses the bad practice:
from argparse import ArgumentParser

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

def main(hparams):
    model = LitModel(hparams)

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=hparams.patience,
    )

    trainer = Trainer.from_argparse_args(
        hparams,
        early_stop_callback=early_stop_callback,
    )

    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--patience', type=int, default=5)

    parser = LitModel.add_model_specific_args(parser)
    parser = Trainer.add_argparse_args(parser)
    hparams = parser.parse_args()

    main(hparams)

Please let me know how to do it in a nice way.

  2. What happens if Trainer and LitModel both use the same argument name for different purposes? In the docs, they are combined into a parent parser:
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    # figure out which model to use
    parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')

    # THIS LINE IS KEY TO PULL THE MODEL NAME
    temp_args, _ = parser.parse_known_args()

    # let the model add what it wants
    if temp_args.model_name == 'gan':
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == 'mnist':
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()

Perhaps a good practice is to parse them separately?

  3. I don't think that using the add_model_specific_args static method and passing only hparams into __init__ is a good idea:
class LitModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument(...)
        ...
        return parser

a) If I properly assigned __init__ arguments to self, the IDE would flag an "Unresolved reference" error before I even run the code. It has happened to me many times that leftovers were hanging around because of this hparams.

b) Basically, add_model_specific_args requires me to track all occurrences of self.hparams in the code visually! If I properly assigned __init__ arguments to self, I would only need to take a quick look at the __init__ section.

I think I know the reasons why hparams was introduced (easy tracking in TensorBoard, rapid research), but I believe this concept should be reconsidered.

@elkotito elkotito added the question Further information is requested label May 5, 2020
@Borda Borda added discussion In a discussion stage help wanted Open to be worked on labels May 5, 2020
@awaelchli
Contributor

What I like about add_model_specific_args is that it makes the LightningModule "portable", meaning that I can have multiple training scripts for the same module, and I don't have to duplicate the model-specific args or keep them in a separate file.

@elkotito
Contributor Author

elkotito commented May 6, 2020

@awaelchli I agree, but add_model_specific_args basically plays the role of the __init__ method, so perhaps a method like Trainer.add_argparse_args is better, because the ArgumentParser can be autogenerated (see Trainer.get_init_arguments_and_types).

@awaelchli
Contributor

awaelchli commented May 9, 2020

You raise multiple issues in your first post and I just commented on one. You can't agree with me and then in the same sentence disagree with everything I said xD That doesn't make sense haha.

Anyway, I will give my thoughts on all the remaining points:

Your point (1): One possibility would be to add the same helper functions from Trainer to the callbacks and loggers, e.g., allowing one to construct a callback with callback = Callback.from_argparse_args(...).
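A rough sketch of what such a helper could look like, filtering the parsed namespace by the callback's signature (hypothetical; no such method exists on callbacks today, and callback_from_argparse_args is an illustrative name):

import inspect
from argparse import Namespace

from pytorch_lightning.callbacks import EarlyStopping


def callback_from_argparse_args(cls, args: Namespace, **kwargs):
    # keep only the namespace entries that the callback's __init__ accepts
    accepted = inspect.signature(cls.__init__).parameters
    cls_kwargs = {k: v for k, v in vars(args).items() if k in accepted}
    cls_kwargs.update(kwargs)  # explicit keyword overrides win
    return cls(**cls_kwargs)


# early_stop = callback_from_argparse_args(EarlyStopping, hparams, monitor='val_loss')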

Your point (2): What would happen is that the argparser would complain that you are trying to add the same argument a second time.

Perhaps a good practice is to parse them separately?

Yes, I agree, and it is already possible; Lightning does not prevent us from doing it. I do this in my scripts so that only the true hparams are passed to my model, with no Trainer args mixed in. This can be achieved with two parsers, using parser.parse_known_args for each of them.
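For reference, a minimal sketch of that two-parser pattern (the model argument is illustrative):

from argparse import ArgumentParser

from pytorch_lightning import Trainer

# one parser for Trainer flags, one for the true model hparams
trainer_parser = Trainer.add_argparse_args(ArgumentParser(add_help=False))
model_parser = ArgumentParser(add_help=False)
model_parser.add_argument('--learning_rate', type=float, default=1e-3)

# parse_known_args lets each parser ignore the other's flags
trainer_args, _ = trainer_parser.parse_known_args()
model_hparams, _ = model_parser.parse_known_args()

trainer = Trainer.from_argparse_args(trainer_args)
# model = LitModel(model_hparams)  # no Trainer args in here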

Your point (3): We store the hparams in a dict (or Namespace) because otherwise it would not be so easy to restore the model for the user (e.g. when loading from checkpoint). We typically want to save the hyperparameters with the model.

I don't understand the rest about the IDE stuff you mention. I think what you want to say is that it is better to just plug all hyperparameters as arguments to the module, but with many hyperparams this becomes a bit hard to manage. Again, how would Lightning know where to grab the hparams?

Don't get me wrong, I'm open to changes. I just also see many advantages of the current setup (best practice?)

@elkotito
Contributor Author

@awaelchli Seems like we misunderstood each other? 😉

What I like about add_model_specific_args is that it makes the LightningModule "portable", meaning that I can have multiple training scripts for the same module, and I don't have to duplicate the model-specific args or keep them in a separate file.

First of all, I don't really understand why you need different training scripts for the same module. I thought the reason behind having Trainer and LightningModule was to make the training script simple: create two objects, add some additional loggers or callbacks, and call the fit method. Please specify the use cases or code examples you had in mind where you would need to "duplicate model specific args" or "keep them in a separate file". Second of all, I don't think I proposed anything that contradicts what you wrote; if I did, please provide an example.

You can even take a look at the AllenNLP project, where the whole training script is compressed into the command allennlp train, which is general enough to handle most cases. In my local PTL setup, I ended up with one train.py file, where I provide a model class and all args from the command line, because I don't want to maintain multiple training scripts for different models.

I don't understand the rest about the IDE stuff you mention. I think what you want to say is that it is better to just plug all hyperparameters as arguments to the module, but with many hyperparams this becomes a bit hard to manage. Again, how would Lightning know where to grab the hparams?

What I mentioned is expressing the model definition in the __init__ method, which is just a valid software development principle. If a new person knows nothing about a class, they look at the __init__ method to learn the usage rules. I agree that it can get messy, but at least you know what the model expects (you detect missing references when the object is created, not during training, and you also get IDE support). You can also ask yourself why Trainer has a proper __init__ rather than a single trainer_params, even though its number of arguments is also overwhelming.

Bear in mind that you can always initialize the object using Model(**hparams). This form lets you use the same model in multiple training scripts, as long as the training scripts are written properly. Take a look at the Trainer class, i.e. its get_init_arguments_and_types and from_argparse_args methods.
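For example, assuming hparams is a parsed argparse Namespace, an explicit constructor can still be driven from a single namespace:

# explicit __init__ arguments, still fed from one parsed namespace
model = LitModel(**vars(hparams))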

Your point (3): We store the hparams in a dict (or Namespace) because otherwise it would not be so easy to restore the model for the user (e.g. when loading from checkpoint). We typically want to save the hyperparameters with the model.

Perhaps a good point, but I don't know PTL that well. Can you elaborate on what the problem is?

Thoughts

I think the topic itself relates very much to how the training takes place. Although it concerns Hydra integration, a similar discussion can be found in #807. @yukw777 looked for the cleanest approach to integrating it, ending up with something the PTL docs consider a bad practice:

DataLoader(
    Dataset.from_data_dir(cfg.dataset.train.dir_path, transform=True),
    shuffle=True,
    batch_size=cfg.dataset.train.batch_size,
    num_workers=cfg.dataset.train.num_workers,
)

In this case, I don't think it's a good idea to extend the pattern to DataLoader in the way you suggested with callback = Callback.from_argparse_args.

In addition:

    module = NetworkLightningModule(OmegaConf.to_container(cfg, resolve=True))
    trainer = Trainer(**OmegaConf.to_container(cfg.pl_trainer, resolve=True))

looks very similar, suggesting that LightningModule could follow a proper constructor as well.

@yukw777
Contributor

yukw777 commented May 11, 2020

@mateuszpieniak hparams is necessary for automatic deserialization of models from a single checkpoint file. PyTorch only provides functionality to save parameter weights, so you need to construct an "empty" model before loading the weights. hparams is what PL uses to "remember" how to construct an empty model. Without it, you'd have to manually pass the arguments to the __init__() of your model yourself.
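A simplified sketch of that round trip (not PL's actual internals; LitModel here stands for the user's module, and hparams is assumed to be an argparse Namespace):

from argparse import Namespace

import torch

# saving: the weights alone don't say how wide/deep the model is
checkpoint = {
    'state_dict': model.state_dict(),
    'hparams': vars(model.hparams),
}
torch.save(checkpoint, 'model.ckpt')

# loading: hparams lets us construct the "empty" model first
checkpoint = torch.load('model.ckpt')
model = LitModel(Namespace(**checkpoint['hparams']))
model.load_state_dict(checkpoint['state_dict'])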

I don't quite get why you think specifying DataLoader like I did is a bad practice. It's a valid pattern for production models. https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#pytorch_lightning.trainer.Trainer.fit.params.train_dataloader

@elkotito
Contributor Author

elkotito commented May 11, 2020

@yukw777

I don't quite get why you think specifying DataLoader like I did is a bad practice. It's a valid pattern for production models. https://pytorch-lightning.readthedocs.io/en/stable/trainer.html#pytorch_lightning.trainer.Trainer.fit.params.train_dataloader

When it comes to DataLoader, I don't think that either 😉 I think the way you initialized the DataLoader is perfectly fine; the docs just say something different about object creation. I was referring to the idea of adding EarlyStopping.from_argparse_args, which is fine, but it also requires adding such a method to many other classes, like DataLoader, just to meet the "good practice" from the docs.

# YES
model = LitModel(hparams)
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)

# NO
# model = LitModel(learning_rate=hparams.learning_rate, ...)
# trainer = Trainer(gpus=hparams.gpus, ...)

It's fair to point out that the docs refer only to the Model and Trainer classes, but I generalized to other cases. What is this "good practice" about, then? I see a pattern: "don't use hparams.xxx to create your object". If my reading is wrong, please provide the correct pattern for the given examples. Bear in mind that if you need specific rules or different good practices only for specific objects, perhaps you are "overfitting" with your coding 😉

hparams is necessary for automatic deserialization of models from a single checkpoint file. PyTorch only provides functionality to save parameter weights, so you need to construct an "empty" model before loading the weights. hparams is what PL uses to "remember" how to construct an empty model. Without it, you'd have to manually pass the arguments to the __init__() of your model yourself.

So if you need to "construct an 'empty' model before loading the weights", use a constructor for that; hence the name in object-oriented programming 😉 For example, there is a reason __setstate__ and __getstate__ exist in Python. If you want a more general solution, you can also inspect your __init__ signature, as you partly do in get_init_arguments_and_types.
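For reference, those hooks look like this in plain Python (a generic example, not PL-specific):

import pickle


class Model:
    def __init__(self, hidden_dim):
        self.hidden_dim = hidden_dim

    def __getstate__(self):
        # called on pickle.dumps: decide what represents the state
        return {'hidden_dim': self.hidden_dim}

    def __setstate__(self, state):
        # called on pickle.loads: rebuild the object from that state
        self.hidden_dim = state['hidden_dim']


restored = pickle.loads(pickle.dumps(Model(128)))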

As you said, hparams is a necessary condition to recreate the object, but only because it's the sole argument of the __init__ method; its existence is not inherently necessary - a different design pattern exists.

My whole point is that I find this kind of initialization an anti-pattern, not a good practice. You can also see that under this link. PL basically does the same: it keeps an opaque dynamic structure (a dict) that you read from dynamically.

@yukw777
Contributor

yukw777 commented May 11, 2020

@mateuszpieniak ah i see what you mean about the way I initialize DataLoaders.

Another thing to consider about hparams is that it needs to be picklable so that it can be saved as part of the checkpoint file. We could inspect the signature of __init__(), but we can't really tell which arguments are picklable or not. An example of an unpicklable object would be the tokenizer from HuggingFace (it is written in Rust with Python bindings, so it can't be pickled).
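To illustrate: about the only general test is a pickle round trip per value, along these lines (illustrative sketch):

import pickle


def picklable_only(hparams: dict) -> dict:
    """Keep the entries that survive pickling; drop the rest."""
    safe = {}
    for key, value in hparams.items():
        try:
            pickle.dumps(value)
            safe[key] = value
        except (TypeError, AttributeError, pickle.PicklingError):
            pass  # e.g. a Rust-backed HuggingFace tokenizer lands here
    return safe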

Having said that, I do think that hparams needs to be looked at again. Should we just have people override __getnewargs_ex__(), __getstate__() and __setstate__()? Maybe the base LightningModule could do the bare minimum and save parameter weights in __getstate__() and load them in __setstate__()?

I believe this discussion is also related to: #1755

@yukw777
Contributor

yukw777 commented May 12, 2020

Also related: #1766

@yukw777
Contributor

yukw777 commented May 13, 2020

Played around with __getnewargs_ex__(), __getstate__() and __setstate__(), and they don't quite work, b/c unpickling doesn't call __init__(), which doesn't play nicely with nn.Module. Also, pickle still preserves information about the class itself, which we don't want if we are to ensure portability of our weight files.

Maybe we need to have a new abstract method on pl.LightningModule that would return the constructor arguments and save this as part of the checkpoint. Something like this...

from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple

from torch import nn


class LightningModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.hparams = {...}  # we can still have hparams initialized here for logging.

    @property
    @abstractmethod
    def constructor_args(self) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
        """Everything returned by this needs to be picklable."""
        ...

    @classmethod
    def _load_model_state(cls, checkpoint: Dict[str, Any]) -> 'LightningModule':
        args, kwargs = checkpoint['constructor_args']
        model = cls(*args, **kwargs)
        model.load_state_dict(checkpoint['state_dict'])

        # give model a chance to load something
        model.on_load_checkpoint(checkpoint)
        return model


class TrainerIOMixin(ABC):
    def dump_checkpoint(self):
        checkpoint = {}

        # add the constructor args and state_dict from the model
        model = self.get_model()

        checkpoint['constructor_args'] = model.constructor_args
        checkpoint['state_dict'] = model.state_dict()

        # save native amp scaling
        if self.use_amp and self.use_native_amp:
            checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()

        # give the model a chance to add a few things
        model.on_save_checkpoint(checkpoint)

        return checkpoint

@yukw777
Contributor

yukw777 commented May 13, 2020

@awaelchli @Borda @williamFalcon thoughts?

@elkotito
Contributor Author

@yukw777

We could inspect the signature of __init__(), but we can't really tell which arguments are picklable or not.

The same logic also applies to self.hparams. If we assume that self.hparams is a dictionary, then we can just as easily pass a non-picklable argument there, meaning it has to be inspected anyway. I understood that the workaround for such a case was to "kindly enforce" the user to initialize a LightningModule with argparse, which supports fairly simple data types that are picklable. Having said that, I don't think it's fair to call it a "good practice" when in reality these are just constraints to make PL work.

Now we come to the question of what we can do about it.

  1. Either way, we need some way to serialize the class properly, because there is no guarantee that hparams contains only simple types (unless we enforce that with proper typing, but then we still need some logic for the different arguments, as well as *args and **kwargs).

I imagined that each __init__ parameter could be inspected during serialization and converted into some deserializable form. Currently my thoughts are around Hydra. I guess similar issues were solved in AllenNLP by using registrable classes.

    self.hparams = {...}  # we can still have hparams initialized here for logging.

Sure, we can still have this, but the logger object has to be created at the same "code level" as the model object. For example, NeptuneLogger has an __init__ argument params, and TensorboardLogger handles kwargs. I don't understand why it needs to be done implicitly.

@williamFalcon
Contributor

these aren't really constraints to make PL work... you can use PL all day long without hparams. The problems you'll eventually run into are:

  1. loading from checkpoint will be a PITA because your model will have no idea how it was designed (num layers, layer dims, etc...)
  2. you won't know what went into making this model.

So, as a best practice we usually store these details in a CSV or in the checkpoint itself.

Do you have any other ways that might be better?

Ideally we don't need hparams but instead read what defaults the model was init'd with and save those. However, I don't think Python has that ability.
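For what it's worth, inspect gets part of the way there: it can read the declared defaults, though not the values actually passed at call time, which is the harder half. A minimal sketch (illustrative, not PL code):

import inspect


def init_defaults(cls) -> dict:
    """Read the default values declared on cls.__init__."""
    return {
        name: param.default
        for name, param in inspect.signature(cls.__init__).parameters.items()
        if param.default is not inspect.Parameter.empty
    }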

@yukw777
Contributor

yukw777 commented May 14, 2020

@williamFalcon what do you think about my proposal? Basically, have people define how their models should be (de)serialized, and PL simply follows that. We can also have people specify the hyperparameters they want tracked with the loggers themselves, as @mateuszpieniak suggested.

@elkotito
Contributor Author

@williamFalcon

these aren't really constraints to make PL work... you can use PL all day long without hparams. The problems you'll eventually run into are:

My point is that you cannot call something "a good practice" when it is a necessary condition for the functionality to work. We should at least specify in the docs that some functionality won't work unless you design your class in a certain way.

Do you have any other ways that might be better? Ideally we don't need hparams but instead read what defaults the model was init with and save those. However, i don't think python has that ability.

Now I see the reason! 😉 Currently I see two solutions to make it more OOP-friendly. Frankly speaking, I don't know whether they are better.

  1. A decorator around __init__ (I don't like it very much):
from functools import wraps


def save_for_checkpoint(init):
    @wraps(init)
    def wrapped_init(self, *args, **kwargs):
        init(self, *args, **kwargs)
        self.init_args = args
        self.init_kwargs = kwargs

    return wrapped_init


class MyModel(LightningModule):
    @save_for_checkpoint
    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        pass
  2. Saving the state by passing the arguments to the base class. This follows OOP, because whenever you have a base class, you need to be aware that it might perform some actions important for your class to work:
class LightningModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.init_args = args
        self.init_kwargs = kwargs

    def forward(self, *args, **kwargs):
        pass


class MyModel(LightningModule):
    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(arg1, arg2, *args, **kwargs)
        self.arg1 = arg1
        self.arg2 = arg2

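Restoring from a checkpoint under this scheme could then look roughly like this (a sketch building on the snippet above):

# saving
checkpoint = {
    'init_args': model.init_args,
    'init_kwargs': model.init_kwargs,
    'state_dict': model.state_dict(),
}

# loading: rebuild the "empty" model first, then load the weights
model = MyModel(*checkpoint['init_args'], **checkpoint['init_kwargs'])
model.load_state_dict(checkpoint['state_dict'])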
Eventually, I think hparams is not a super bad workaround, but I would suggest making it a mandatory argument & specifying the accepted type. We also need to mention in the docs that all arguments not passed through hparams require a default value or need to be passed when loading from a checkpoint.

@yukw777
Contributor

yukw777 commented May 14, 2020

@mateuszpieniak nice suggestions! I like your second suggestion too. I guess I thought it'd be better to have the user be explicit about what they want to save rather than trying to save it automatically, hence my suggestion... but I guess we can be more clever, since we're basically just saving what needs to be passed to __init__().

@yukw777
Contributor

yukw777 commented May 14, 2020

Thinking about it more... I don't think Option 2 would work, @mateuszpieniak, when there are multiple levels of inheritance. Consider this:

class LightningModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.init_args = args
        self.init_kwargs = kwargs

    def forward(self, *args, **kwargs):
        pass


class MyParentModel(LightningModule):
    def __init__(self, p_arg1, p_arg2, p_kwargs1=default, p_kwargs2=default2):
        super().__init__(p_arg1, p_arg2, p_kwargs1=p_kwargs1, p_kwargs2=p_kwargs2)
        self.p_arg1 = p_arg1
        self.p_arg2 = p_arg2


class MyChildModel(MyParentModel):
    def __init__(self, c_arg1, c_arg2, c_kwargs1=default):
        super().__init__(c_arg1, c_arg2, p_kwargs1=c_kwargs1)
        self.c_arg1 = c_arg1
        self.c_arg2 = c_arg2

Then, LightningModule would save self.init_args = (c_arg1, c_arg2) and self.init_kwargs = {'p_kwargs1': default, 'p_kwargs2': default2}. But this is a problem when we deserialize MyChildModel, b/c its __init__ doesn't accept p_kwargs1 or p_kwargs2.
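Concretely, the round trip would fail along these lines (a sketch continuing the example above):

model = MyChildModel('a', 'b')
# LightningModule captured what *it* received, not what the leaf class accepts:
#   init_args = ('a', 'b')
#   init_kwargs = {'p_kwargs1': default, 'p_kwargs2': default2}
MyChildModel(*model.init_args, **model.init_kwargs)
# TypeError: __init__() got an unexpected keyword argument 'p_kwargs1'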

My suggestion and your first suggestion would work b/c they save the constructor arguments at the leaf level.

@elkotito
Contributor Author

elkotito commented May 15, 2020

@yukw777 You're right 😿
New thoughts coming!

  3. We could make our model class a dataclass. We would easily save the object's state by simply reading from self. The dataclass decorator makes sure the fields exist, because it is responsible for creating them. In addition, it validates what has been passed to the constructor, meaning no hanging kwargs and args that would break the serialization (exceptions will be raised). Bear in mind that it still supports object creation with unrolled kwargs or args, like model = Model(*args, **kwargs); a sketch follows at the end of this comment. Now the drawbacks:
  • Formally, a person can still create an __init__ method, but it doesn't make sense.
  • For custom initialization there is the __post_init__ method. To pass an argument to __post_init__, its type has to be wrapped in InitVar. Perhaps only that part would require manual (de)serialization logic from the class user. I think it's a good trade-off between automatic and manual serialization logic. It's similar to 1), but at least the decorator is aligned with a PEP standard. The dataclass arguments can easily be obtained via __annotations__.
  4. I took a look at AllenNLP, which follows a lot of good practices in general. Its whole training process is controlled via configuration files, meaning that objects are mostly created using the from_params(...) method, which is very similar to our Trainer.from_argparse_args. Such an approach gives you a "state descriptor" that is sufficient to reproduce the object's state. I don't really know what happens if you create an object directly using the __init__ method, but perhaps it's worth investigating. What should we do then? Raise an exception or a warning that checkpoints are not supported? With such an approach and Hydra combined, we could create a "killing machine".

What do you think? I personally like both approaches, but 4) might be a bit risky, since we would fully rely on one fairly new technology. If we need to control the object's creation process, then perhaps we should consider a factory design pattern? I mean, we don't just create the object explicitly; we have some other object that controls that.
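A rough sketch of the dataclass idea from point 3 (hypothetical: it assumes LightningModule tolerates dataclass field assignment before nn.Module.__init__ runs, which would need verifying):

from dataclasses import asdict, dataclass

from torch import nn


@dataclass
class MyModel(LightningModule):  # hypothetical cooperation with dataclasses
    hidden_dim: int = 128
    learning_rate: float = 1e-3

    def __post_init__(self):
        super().__init__()  # nn.Module bookkeeping must still run
        self.layer = nn.Linear(self.hidden_dim, self.hidden_dim)


# the serializable state is exactly the declared fields:
# asdict(model) -> {'hidden_dim': 128, 'learning_rate': 0.001}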

@williamFalcon
Contributor

fixed on master
