From 3593feb04de45542101a2922943fea77e20d7783 Mon Sep 17 00:00:00 2001 From: guarin <43336610+guarin@users.noreply.github.com> Date: Fri, 16 Aug 2024 10:09:08 +0200 Subject: [PATCH] Use TiCoTransform Everywhere (#1634) * Replace BYOLTransform with TiCoTransform * Remove unused imports --- benchmarks/imagenet/resnet50/tico.py | 11 +++-------- examples/notebooks/pytorch/tico.ipynb | 8 ++------ examples/notebooks/pytorch_lightning/tico.ipynb | 13 ++++--------- .../pytorch_lightning_distributed/tico.ipynb | 13 ++++--------- examples/pytorch/tico.py | 8 ++------ examples/pytorch_lightning/tico.py | 13 ++++--------- examples/pytorch_lightning_distributed/tico.py | 13 ++++--------- 7 files changed, 23 insertions(+), 56 deletions(-) diff --git a/benchmarks/imagenet/resnet50/tico.py b/benchmarks/imagenet/resnet50/tico.py index fa98d53b7..50c522bdd 100644 --- a/benchmarks/imagenet/resnet50/tico.py +++ b/benchmarks/imagenet/resnet50/tico.py @@ -9,12 +9,8 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead -from lightly.models.utils import ( - deactivate_requires_grad, - get_weight_decay_parameters, - update_momentum, -) -from lightly.transforms import BYOLTransform +from lightly.models.utils import get_weight_decay_parameters, update_momentum +from lightly.transforms import TiCoTransform from lightly.utils.benchmarking import OnlineLinearClassifier from lightly.utils.lars import LARS from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule @@ -135,5 +131,4 @@ def configure_optimizers(self): return [optimizer], [scheduler] -# TiCo uses BYOL augmentations. -transform = BYOLTransform() +transform = TiCoTransform() diff --git a/examples/notebooks/pytorch/tico.ipynb b/examples/notebooks/pytorch/tico.ipynb index 54dcdbb34..a9fe8e18c 100644 --- a/examples/notebooks/pytorch/tico.ipynb +++ b/examples/notebooks/pytorch/tico.ipynb @@ -61,11 +61,7 @@ "from lightly.loss.tico_loss import TiCoLoss\n", "from lightly.models.modules.heads import TiCoProjectionHead\n", "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", - "from lightly.transforms.tico_transform import (\n", - " TiCoTransform,\n", - " TiCoView1Transform,\n", - " TiCoView2Transform,\n", - ")\n", + "from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n", "from lightly.utils.scheduler import cosine_schedule" ] }, @@ -135,7 +131,7 @@ "# We disable resizing and gaussian blur for cifar10.\n", "transform = TiCoTransform(\n", " view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n", - " view_2_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n", ")\n", "dataset = torchvision.datasets.CIFAR10(\n", " \"datasets/cifar10\", download=True, transform=transform\n", diff --git a/examples/notebooks/pytorch_lightning/tico.ipynb b/examples/notebooks/pytorch_lightning/tico.ipynb index 9d39b1553..266d2b993 100644 --- a/examples/notebooks/pytorch_lightning/tico.ipynb +++ b/examples/notebooks/pytorch_lightning/tico.ipynb @@ -52,11 +52,7 @@ "from lightly.loss.tico_loss import TiCoLoss\n", "from lightly.models.modules.heads import TiCoProjectionHead\n", "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", - "from lightly.transforms.byol_transform import (\n", - " BYOLTransform,\n", - " BYOLView1Transform,\n", - " BYOLView2Transform,\n", - ")\n", + "from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n", "from lightly.utils.scheduler import cosine_schedule" ] }, @@ -126,11 +122,10 @@ "metadata": {}, "outputs": [], "source": [ - "# TiCo uses BYOL augmentations.\n", "# We disable resizing and gaussian blur for cifar10.\n", - "transform = BYOLTransform(\n", - " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", - " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + "transform = TiCoTransform(\n", + " view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n", ")\n", "dataset = torchvision.datasets.CIFAR10(\n", " \"datasets/cifar10\", download=True, transform=transform\n", diff --git a/examples/notebooks/pytorch_lightning_distributed/tico.ipynb b/examples/notebooks/pytorch_lightning_distributed/tico.ipynb index 100881e3d..528bd997d 100644 --- a/examples/notebooks/pytorch_lightning_distributed/tico.ipynb +++ b/examples/notebooks/pytorch_lightning_distributed/tico.ipynb @@ -52,11 +52,7 @@ "from lightly.loss.tico_loss import TiCoLoss\n", "from lightly.models.modules.heads import TiCoProjectionHead\n", "from lightly.models.utils import deactivate_requires_grad, update_momentum\n", - "from lightly.transforms.byol_transform import (\n", - " BYOLTransform,\n", - " BYOLView1Transform,\n", - " BYOLView2Transform,\n", - ")\n", + "from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n", "from lightly.utils.scheduler import cosine_schedule" ] }, @@ -126,11 +122,10 @@ "metadata": {}, "outputs": [], "source": [ - "# TiCo uses BYOL augmentations.\n", "# We disable resizing and gaussian blur for cifar10.\n", - "transform = BYOLTransform(\n", - " view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n", - " view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n", + "transform = TiCoTransform(\n", + " view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n", + " view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n", ")\n", "dataset = torchvision.datasets.CIFAR10(\n", " \"datasets/cifar10\", download=True, transform=transform\n", diff --git a/examples/pytorch/tico.py b/examples/pytorch/tico.py index 841129753..f66b6b3ea 100644 --- a/examples/pytorch/tico.py +++ b/examples/pytorch/tico.py @@ -14,11 +14,7 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.tico_transform import ( - TiCoTransform, - TiCoView1Transform, - TiCoView2Transform, -) +from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform from lightly.utils.scheduler import cosine_schedule @@ -58,7 +54,7 @@ def forward_momentum(self, x): # We disable resizing and gaussian blur for cifar10. transform = TiCoTransform( view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0), - view_2_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0), ) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform diff --git a/examples/pytorch_lightning/tico.py b/examples/pytorch_lightning/tico.py index 1032ffded..c8dcdae46 100644 --- a/examples/pytorch_lightning/tico.py +++ b/examples/pytorch_lightning/tico.py @@ -11,11 +11,7 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.byol_transform import ( - BYOLTransform, - BYOLView1Transform, - BYOLView2Transform, -) +from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform from lightly.utils.scheduler import cosine_schedule @@ -63,11 +59,10 @@ def configure_optimizers(self): model = TiCo() -# TiCo uses BYOL augmentations. # We disable resizing and gaussian blur for cifar10. -transform = BYOLTransform( - view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), - view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +transform = TiCoTransform( + view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0), ) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform diff --git a/examples/pytorch_lightning_distributed/tico.py b/examples/pytorch_lightning_distributed/tico.py index ce4ecf9c2..b5a8a1190 100644 --- a/examples/pytorch_lightning_distributed/tico.py +++ b/examples/pytorch_lightning_distributed/tico.py @@ -11,11 +11,7 @@ from lightly.loss.tico_loss import TiCoLoss from lightly.models.modules.heads import TiCoProjectionHead from lightly.models.utils import deactivate_requires_grad, update_momentum -from lightly.transforms.byol_transform import ( - BYOLTransform, - BYOLView1Transform, - BYOLView2Transform, -) +from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform from lightly.utils.scheduler import cosine_schedule @@ -63,11 +59,10 @@ def configure_optimizers(self): model = TiCo() -# TiCo uses BYOL augmentations. # We disable resizing and gaussian blur for cifar10. -transform = BYOLTransform( - view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0), - view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0), +transform = TiCoTransform( + view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0), + view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0), ) dataset = torchvision.datasets.CIFAR10( "datasets/cifar10", download=True, transform=transform