new: add quaterion progress bar, add trainer defaults, move callbacks… (#114)

* new: add quaterion progress bar, add trainer defaults, move callbacks #113

* new: replace max epochs with early stopping callback #113

* adjust progress bar + more default params

* fixup: Format Python code with Black

* rich model summary

* rm unused imports

* review fix

* fixup: Format Python code with Black

* review fix

* fixup: Format Python code with Black

* do not set default trainer

Co-authored-by: Andrey Vasnetsov <andrey@vasnetsov.com>
Co-authored-by: autoblack <qdrant@users.noreply.github.com>
3 people committed Jun 8, 2022
1 parent 3078522 commit 83b901c
Showing 8 changed files with 271 additions and 45 deletions.
15 changes: 1 addition & 14 deletions examples/cars/train.py
@@ -33,22 +33,9 @@ def train(
         batch_size=batch_size, input_size=input_size, shuffle=shuffle
     )
 
-    early_stopping = EarlyStopping(
-        monitor="validation_loss",
-        patience=50,
-    )
-
-    trainer = pl.Trainer(
-        gpus=1 if torch.cuda.is_available() else 0,
-        max_epochs=epochs,
-        callbacks=[early_stopping, ModelSummary(max_depth=3)],
-        enable_checkpointing=False,
-        log_every_n_steps=1,
-    )
-
     Quaterion.fit(
         trainable_model=model,
-        trainer=trainer,
+        trainer=None,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )
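Note on the change above: the example no longer constructs its own `pl.Trainer`. Passing `trainer=None` tells `Quaterion.fit` to assemble a trainer from `Quaterion.trainer_defaults` (introduced in `quaterion/main.py` below). A minimal sketch of the new calling convention, where `model`, `train_dataloader`, and `val_dataloader` stand in for your own `TrainableModel` and `SimilarityDataLoader` instances::

    from quaterion import Quaterion

    # trainer=None -> Quaterion builds a Trainer from trainer_defaults():
    # progress bar, early stopping on validation loss, rich model summary.
    Quaterion.fit(
        trainable_model=model,
        trainer=None,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )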
76 changes: 56 additions & 20 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -21,6 +21,8 @@ loguru = "^0.5.3"
 mmh3 = "^3.0.0"
 pytorch-metric-learning = {version = "^1.3.0", optional = true}
 protobuf = ">= 3.9.2, <3.20" # until https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 is fixed
+rich = "^12.4.4"
+torchmetrics = "<=0.8.2" # Until warning with `full_state_update` is fixed
 
 
 [tool.poetry.dev-dependencies]
108 changes: 97 additions & 11 deletions quaterion/main.py
@@ -1,12 +1,10 @@
-import warnings
-
-import torch
-from typing import Optional, Union, Sized, Iterable, Dict
-
 import pytorch_lightning as pl
+import torch
+import warnings
+from pytorch_lightning.callbacks import EarlyStopping, RichModelSummary
 from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.utils.data import Dataset
-from quaterion_models import SimilarityModel
+from typing import Optional, Union, Sized, Iterable, Dict
 
 from quaterion.dataset.similarity_data_loader import (
     PairsSimilarityDataLoader,
@@ -15,9 +13,12 @@
 )
 from quaterion.eval.evaluator import Evaluator
 from quaterion.loss import GroupLoss, PairwiseLoss
-from quaterion.train.cleanup_callback import CleanupCallback
-from quaterion.train.metrics_callback import MetricsCallback
+from quaterion.train.cache import CacheType
+from quaterion.train.callbacks import CleanupCallback, MetricsCallback
 from quaterion.train.trainable_model import TrainableModel
 from quaterion.utils.enums import TrainStage
+from quaterion.utils.progress_bar import QuaterionProgressBar
+from quaterion_models import SimilarityModel
 
 
 class Quaterion:
Expand All @@ -27,7 +28,7 @@ class Quaterion:
def fit(
cls,
trainable_model: TrainableModel,
trainer: pl.Trainer,
trainer: Optional[pl.Trainer],
train_dataloader: SimilarityDataLoader,
val_dataloader: Optional[SimilarityDataLoader] = None,
ckpt_path: Optional[str] = None,
@@ -38,8 +39,11 @@ def fit(
         Args:
             trainable_model: model to fit
-            trainer: `pytorch_lightning.Trainer` instance to handle fitting routine
-                internally
+            trainer:
+                `pytorch_lightning.Trainer` instance to handle the fitting routine internally.
+                If `None` is passed, a trainer will be created with :meth:`Quaterion.trainer_defaults`.
+                The default parameters are intended as a quick start for training the model, and we
+                encourage users to try different parameters if the defaults do not give a satisfactory result.
             train_dataloader: DataLoader instance to retrieve samples during training
                 stage
             val_dataloader: Optional DataLoader instance to retrieve samples during
@@ -62,6 +66,13 @@ def fit(
                 "Try other loss/data loader"
             )
 
+        if trainer is None:
+            trainer = pl.Trainer(
+                **cls.trainer_defaults(
+                    trainable_model=trainable_model, train_dataloader=train_dataloader
+                )
+            )
+
         trainer.callbacks.append(CleanupCallback())
         trainer.callbacks.append(MetricsCallback())
         # Prepare data loaders for training
@@ -110,3 +121,78 @@ def evaluate(
         """
         return evaluator.evaluate(dataset, model)
 
+    @staticmethod
+    def trainer_defaults(
+        trainable_model: TrainableModel = None,
+        train_dataloader: SimilarityDataLoader = None,
+    ):
+        """Reasonable default parameters for `pytorch_lightning.Trainer`
+
+        This function generates a parameter set for Trainer which is considered
+        "recommended" for most use-cases of Quaterion. Quaterion's similarity
+        learning process has characteristics that differentiate it from regular
+        deep learning model training, and these defaults may be overridden if
+        you need special behaviour for your task.
+
+        Example::
+
+            trainer_kwargs = Quaterion.trainer_defaults(
+                trainable_model=model,
+                train_dataloader=train_dataloader,
+            )
+            trainer_kwargs['logger'] = pl.loggers.WandbLogger(
+                name="example_model",
+                project="example_project",
+            )
+            trainer_kwargs['callbacks'].append(YourCustomCallback())
+            trainer = pl.Trainer(**trainer_kwargs)
+
+        Args:
+            trainable_model: if provided, default params are adjusted based on
+                the model configuration
+            train_dataloader: if provided, trainer params are adjusted according
+                to the dataset
+
+        Returns:
+            kwargs for `pytorch_lightning.Trainer`
+        """
+        use_gpu = torch.cuda.is_available()
+        defaults = {
+            "callbacks": [
+                QuaterionProgressBar(console_kwargs={"tab_size": 4}),
+                EarlyStopping(f"{TrainStage.VALIDATION}_loss"),
+                RichModelSummary(max_depth=3),
+            ],
+            "gpus": int(use_gpu),
+            "auto_select_gpus": use_gpu,
+            "max_epochs": -1,
+            "enable_model_summary": False,  # We define our custom model summary
+        }
+
+        # Adjust default parameters according to the dataloader configuration
+        if train_dataloader:
+            try:
+                num_batches = len(train_dataloader)
+                if num_batches > 0:
+                    defaults["log_every_n_steps"] = min(50, num_batches)
+            except Exception:  # If the dataset has no length
+                pass
+
+        # Adjust default parameters according to the model configuration
+        if trainable_model:
+            # If the cache is enabled and there are no trainable encoders,
+            # checkpointing on each epoch might become a bottleneck
+            cache_config = trainable_model.configure_caches()
+            all_encoders_frozen = all(
+                not encoder.trainable
+                for encoder in trainable_model.model.encoders.values()
+            )
+            cache_configured = (
+                cache_config is not None and cache_config.cache_type != CacheType.NONE
+            )
+            disable_checkpoints = all_encoders_frozen and cache_configured
+
+            if disable_checkpoints:
+                defaults["enable_checkpointing"] = False
+
+        return defaults
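With `max_epochs` set to `-1`, training length under these defaults is governed entirely by the `EarlyStopping` callback, which monitors `validation_loss` and stops after pytorch-lightning's default `patience=3` non-improving validation epochs. A sketch of restoring the more patient behaviour the cars example previously hard-coded (`model` and `train_dataloader` are placeholders)::

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import EarlyStopping

    from quaterion import Quaterion

    trainer_kwargs = Quaterion.trainer_defaults(
        trainable_model=model, train_dataloader=train_dataloader
    )
    # Swap the default EarlyStopping (patience=3) for a more patient one.
    trainer_kwargs["callbacks"] = [
        cb for cb in trainer_kwargs["callbacks"] if not isinstance(cb, EarlyStopping)
    ]
    trainer_kwargs["callbacks"].append(
        EarlyStopping(monitor="validation_loss", patience=50)
    )
    trainer = pl.Trainer(**trainer_kwargs)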
2 changes: 2 additions & 0 deletions quaterion/train/callbacks/__init__.py
@@ -0,0 +1,2 @@
+from quaterion.train.callbacks.cleanup_callback import CleanupCallback
+from quaterion.train.callbacks.metrics_callback import MetricsCallback
quaterion/train/cleanup_callback.py → quaterion/train/callbacks/cleanup_callback.py: File renamed without changes.
quaterion/train/metrics_callback.py → quaterion/train/callbacks/metrics_callback.py: File renamed without changes.
