diff --git a/examples/cars/train.py b/examples/cars/train.py index 5131979b..630b3c05 100644 --- a/examples/cars/train.py +++ b/examples/cars/train.py @@ -35,6 +35,7 @@ def train( Quaterion.fit( trainable_model=model, + trainer=None, train_dataloader=train_dataloader, val_dataloader=val_dataloader, ) diff --git a/quaterion/main.py b/quaterion/main.py index bf5b4fda..220d6873 100644 --- a/quaterion/main.py +++ b/quaterion/main.py @@ -28,9 +28,9 @@ class Quaterion: def fit( cls, trainable_model: TrainableModel, + trainer: Optional[pl.Trainer], train_dataloader: SimilarityDataLoader, val_dataloader: Optional[SimilarityDataLoader] = None, - trainer: Optional[pl.Trainer] = None, ckpt_path: Optional[str] = None, ): """Handle training routine @@ -39,12 +39,15 @@ def fit( Args: trainable_model: model to fit + trainer: + `pytorch_lightning.Trainer` instance to handle fitting routine internally. + If `None` is passed, trainer will be created with :meth:`Quaterion.trainer_defaults`. + The default parameters are intended to serve as a quick start for learning the model, and we + encourage users to try different parameters if the default ones do not give a satisfactory result. train_dataloader: DataLoader instance to retrieve samples during training stage val_dataloader: Optional DataLoader instance to retrieve samples during validation stage - trainer: `pytorch_lightning.Trainer` instance to handle fitting routine - internally. If not passed will be created with `Quaterion.trainer_defaults` ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch. @@ -132,6 +135,21 @@ def trainer_defaults( regular deep learning model training. This default parameters may be overwritten, if you need some special behaviour for your special task. 
+ Consider overriding default parameters if you need to adjust Trainer parameters: + + Example:: + + trainer_kwargs = Quaterion.trainer_defaults( + trainable_model=model, + train_dataloader=train_dataloader + ) + trainer_kwargs['logger'] = pl.loggers.WandbLogger( + name="example_model", + project="example_project", + ) + trainer_kwargs['callbacks'].append(YourCustomCallback()) + trainer = pl.Trainer(**trainer_kwargs) + Args: trainable_model: We will try to adjust default params based on model configuration, if provided train_dataloader: If provided, trainer params will be adjusted according to dataset