Skip to content

Commit

Permalink
do not set default trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
generall committed Jun 8, 2022
1 parent 79243e2 commit 1276aeb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/cars/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def train(

Quaterion.fit(
trainable_model=model,
trainer=None,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
)
Expand Down
24 changes: 21 additions & 3 deletions quaterion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class Quaterion:
def fit(
cls,
trainable_model: TrainableModel,
trainer: Optional[pl.Trainer],
train_dataloader: SimilarityDataLoader,
val_dataloader: Optional[SimilarityDataLoader] = None,
trainer: Optional[pl.Trainer] = None,
ckpt_path: Optional[str] = None,
):
"""Handle training routine
Expand All @@ -39,12 +39,15 @@ def fit(
Args:
trainable_model: model to fit
trainer:
`pytorch_lightning.Trainer` instance to handle fitting routine internally.
If `None` passed, trainer will be created with :meth:`Quaterion.trainer_defaults`.
The default parameters are intended to serve as a quick start for learning the model, and we
encourage users to try different parameters if the default ones do not give a satisfactory result.
train_dataloader: DataLoader instance to retrieve samples during training
stage
val_dataloader: Optional DataLoader instance to retrieve samples during
validation stage
trainer: `pytorch_lightning.Trainer` instance to handle fitting routine
internally. If not passed will be created with `Quaterion.trainer_defaults`
ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
Expand Down Expand Up @@ -132,6 +135,21 @@ def trainer_defaults(
regular deep learning model training.
This default parameters may be overwritten, if you need some special behaviour for your special task.
Consider overriding default parameters if you need to adjust Trainer parameters:
Example::
trainer_kwargs = Quaterion.trainer_defaults(
trainable_model=model,
train_dataloader=train_dataloader
)
trainer_kwargs['logger'] = pl.loggers.WandbLogger(
name="example_model",
project="example_project",
)
trainer_kwargs['callbacks'].append(YourCustomCallback())
trainer = pl.Trainer(**trainer_kwargs)
Args:
trainable_model: We will try to adjust default params based on model configuration, if provided
train_dataloader: If provided, trainer params will be adjusted according to dataset
Expand Down

0 comments on commit 1276aeb

Please sign in to comment.