new: add quaterion progress bar, add trainer defaults, move callbacks… #114

Merged
merged 11 commits on Jun 8, 2022
15 changes: 1 addition & 14 deletions examples/cars/train.py
@@ -33,22 +33,9 @@ def train(
batch_size=batch_size, input_size=input_size, shuffle=shuffle
)

early_stopping = EarlyStopping(
monitor="validation_loss",
patience=50,
)

trainer = pl.Trainer(
gpus=1 if torch.cuda.is_available() else 0,
max_epochs=epochs,
callbacks=[early_stopping, ModelSummary(max_depth=3)],
enable_checkpointing=False,
log_every_n_steps=1,
)

Quaterion.fit(
trainable_model=model,
trainer=trainer,
trainer=None,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
)
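
With this change the example no longer builds a `pytorch_lightning.Trainer` by hand. A minimal sketch of the simplified call (assuming the `model`, `train_dataloader` and `val_dataloader` objects created earlier in examples/cars/train.py, which are outside this hunk):

    from quaterion import Quaterion

    # Passing trainer=None lets Quaterion build a pl.Trainer from
    # Quaterion.trainer_defaults(): Quaterion progress bar, early stopping on
    # validation loss, rich model summary, and GPU auto-selection when available.
    Quaterion.fit(
        trainable_model=model,
        trainer=None,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )
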
76 changes: 56 additions & 20 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -21,6 +21,8 @@ loguru = "^0.5.3"
mmh3 = "^3.0.0"
pytorch-metric-learning = {version = "^1.3.0", optional = true}
protobuf = ">= 3.9.2, <3.20" # until https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 is fixed
rich = "^12.4.4"
torchmetrics = "<=0.8.2" # Until warning with `full_state_update` is fixed


[tool.poetry.dev-dependencies]
108 changes: 97 additions & 11 deletions quaterion/main.py
@@ -1,12 +1,10 @@
import warnings

import torch
from typing import Optional, Union, Sized, Iterable, Dict

import pytorch_lightning as pl
import torch
import warnings
from pytorch_lightning.callbacks import EarlyStopping, RichModelSummary
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import Dataset
from quaterion_models import SimilarityModel
from typing import Optional, Union, Sized, Iterable, Dict

from quaterion.dataset.similarity_data_loader import (
PairsSimilarityDataLoader,
@@ -15,9 +13,12 @@
)
from quaterion.eval.evaluator import Evaluator
from quaterion.loss import GroupLoss, PairwiseLoss
from quaterion.train.cleanup_callback import CleanupCallback
from quaterion.train.metrics_callback import MetricsCallback
from quaterion.train.cache import CacheType
from quaterion.train.callbacks import CleanupCallback, MetricsCallback
from quaterion.train.trainable_model import TrainableModel
from quaterion.utils.enums import TrainStage
from quaterion.utils.progress_bar import QuaterionProgressBar
from quaterion_models import SimilarityModel


class Quaterion:
@@ -27,7 +28,7 @@ class Quaterion:
def fit(
cls,
trainable_model: TrainableModel,
trainer: pl.Trainer,
trainer: Optional[pl.Trainer],
train_dataloader: SimilarityDataLoader,
val_dataloader: Optional[SimilarityDataLoader] = None,
ckpt_path: Optional[str] = None,
@@ -38,8 +39,11 @@ def fit(

Args:
trainable_model: model to fit
trainer: `pytorch_lightning.Trainer` instance to handle fitting routine
internally
trainer:
`pytorch_lightning.Trainer` instance to handle fitting routine internally.
If `None` is passed, a trainer will be created with :meth:`Quaterion.trainer_defaults`.
The default parameters are intended as a quick start for training the model; we
encourage users to try other parameters if the defaults do not give a satisfactory result.
train_dataloader: DataLoader instance to retrieve samples during training
stage
val_dataloader: Optional DataLoader instance to retrieve samples during
@@ -62,6 +66,13 @@
"Try other loss/data loader"
)

if trainer is None:
trainer = pl.Trainer(
**cls.trainer_defaults(
trainable_model=trainable_model, train_dataloader=train_dataloader
)
)

trainer.callbacks.append(CleanupCallback())
trainer.callbacks.append(MetricsCallback())
# Prepare data loaders for training
@@ -110,3 +121,78 @@ def evaluate(

"""
return evaluator.evaluate(dataset, model)

@staticmethod
def trainer_defaults(
trainable_model: TrainableModel = None,
train_dataloader: SimilarityDataLoader = None,
):
"""Reasonable default parameters for `pytorch_lightning.Trainer`

This function generates a parameter set for Trainer which is considered
"recommended" for most use cases of Quaterion.
The similarity learning process in Quaterion has characteristics that differentiate it from
regular deep learning model training, and these defaults reflect that.

Consider overriding the default parameters if you need special behaviour for your task:

Example::

trainer_kwargs = Quaterion.trainer_defaults(
trainable_model=model,
train_dataloader=train_dataloader
)
trainer_kwargs['logger'] = pl.loggers.WandbLogger(
name="example_model",
project="example_project",
)
trainer_kwargs['callbacks'].append(YourCustomCallback())
trainer = pl.Trainer(**trainer_kwargs)

Args:
trainable_model: If provided, default params are adjusted based on the model configuration
train_dataloader: If provided, trainer params are adjusted according to the dataset

Returns:
kwargs for `pytorch_lightning.Trainer`
"""
use_gpu = torch.cuda.is_available()
defaults = {
"callbacks": [
QuaterionProgressBar(console_kwargs={"tab_size": 4}),
EarlyStopping(f"{TrainStage.VALIDATION}_loss"),
RichModelSummary(max_depth=3),
],
"gpus": int(use_gpu),
"auto_select_gpus": use_gpu,
"max_epochs": -1,
"enable_model_summary": False, # We define our custom model summary
}

# Adjust default parameters according to the dataloader configuration
if train_dataloader:
try:
num_batches = len(train_dataloader)
if num_batches > 0:
defaults["log_every_n_steps"] = min(50, num_batches)
except Exception:  # If the dataset has no length

Member Author
Maybe specify exception? (TypeError, NotImplementedError)

Member
thought about it, but I don't trust this list used in Lightning; the whole section is optional, so I am fine with it failing for whatever reason

pass

# Adjust default parameters according to model configuration
if trainable_model:
# If the cache is enabled and there are no
# trainable encoders - checkpointing on each epoch might become a bottleneck
cache_config = trainable_model.configure_caches()
all_encoders_frozen = all(
not encoder.trainable
for encoder in trainable_model.model.encoders.values()
)
cache_configured = (
cache_config is not None and cache_config.cache_type != CacheType.NONE
)
disable_checkpoints = all_encoders_frozen and cache_configured

if disable_checkpoints:
defaults["enable_checkpointing"] = False
return defaults
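
As a usage sketch of how these defaults are meant to be consumed outside of `fit` (here `model`, `train_dataloader` and `val_dataloader` are placeholders for objects created elsewhere):

    import pytorch_lightning as pl

    from quaterion import Quaterion

    # Start from the recommended defaults and tweak only what is needed.
    trainer_kwargs = Quaterion.trainer_defaults(
        trainable_model=model, train_dataloader=train_dataloader
    )
    # max_epochs defaults to -1 (train until EarlyStopping stops it); cap it if preferred.
    trainer_kwargs["max_epochs"] = 300
    trainer = pl.Trainer(**trainer_kwargs)

    Quaterion.fit(
        trainable_model=model,
        trainer=trainer,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
    )
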
2 changes: 2 additions & 0 deletions quaterion/train/callbacks/__init__.py
@@ -0,0 +1,2 @@
from quaterion.train.callbacks.cleanup_callback import CleanupCallback
from quaterion.train.callbacks.metrics_callback import MetricsCallback