Skip to content

Commit

Permalink
Use TiCoTransform Everywhere (#1634)
Browse files Browse the repository at this point in the history
* Replace BYOLTransform with TiCoTransform
* Remove unused imports
  • Loading branch information
guarin committed Aug 16, 2024
1 parent 222d97b commit 3593feb
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 56 deletions.
11 changes: 3 additions & 8 deletions benchmarks/imagenet/resnet50/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,8 @@

from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import (
deactivate_requires_grad,
get_weight_decay_parameters,
update_momentum,
)
from lightly.transforms import BYOLTransform
from lightly.models.utils import get_weight_decay_parameters, update_momentum
from lightly.transforms import TiCoTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.lars import LARS
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule
Expand Down Expand Up @@ -135,5 +131,4 @@ def configure_optimizers(self):
return [optimizer], [scheduler]


# TiCo uses BYOL augmentations.
transform = BYOLTransform()
transform = TiCoTransform()
8 changes: 2 additions & 6 deletions examples/notebooks/pytorch/tico.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@
"from lightly.loss.tico_loss import TiCoLoss\n",
"from lightly.models.modules.heads import TiCoProjectionHead\n",
"from lightly.models.utils import deactivate_requires_grad, update_momentum\n",
"from lightly.transforms.tico_transform import (\n",
" TiCoTransform,\n",
" TiCoView1Transform,\n",
" TiCoView2Transform,\n",
")\n",
"from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n",
"from lightly.utils.scheduler import cosine_schedule"
]
},
Expand Down Expand Up @@ -135,7 +131,7 @@
"# We disable resizing and gaussian blur for cifar10.\n",
"transform = TiCoTransform(\n",
" view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n",
")\n",
"dataset = torchvision.datasets.CIFAR10(\n",
" \"datasets/cifar10\", download=True, transform=transform\n",
Expand Down
13 changes: 4 additions & 9 deletions examples/notebooks/pytorch_lightning/tico.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@
"from lightly.loss.tico_loss import TiCoLoss\n",
"from lightly.models.modules.heads import TiCoProjectionHead\n",
"from lightly.models.utils import deactivate_requires_grad, update_momentum\n",
"from lightly.transforms.byol_transform import (\n",
" BYOLTransform,\n",
" BYOLView1Transform,\n",
" BYOLView2Transform,\n",
")\n",
"from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n",
"from lightly.utils.scheduler import cosine_schedule"
]
},
Expand Down Expand Up @@ -126,11 +122,10 @@
"metadata": {},
"outputs": [],
"source": [
"# TiCo uses BYOL augmentations.\n",
"# We disable resizing and gaussian blur for cifar10.\n",
"transform = BYOLTransform(\n",
" view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n",
"transform = TiCoTransform(\n",
" view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n",
")\n",
"dataset = torchvision.datasets.CIFAR10(\n",
" \"datasets/cifar10\", download=True, transform=transform\n",
Expand Down
13 changes: 4 additions & 9 deletions examples/notebooks/pytorch_lightning_distributed/tico.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,7 @@
"from lightly.loss.tico_loss import TiCoLoss\n",
"from lightly.models.modules.heads import TiCoProjectionHead\n",
"from lightly.models.utils import deactivate_requires_grad, update_momentum\n",
"from lightly.transforms.byol_transform import (\n",
" BYOLTransform,\n",
" BYOLView1Transform,\n",
" BYOLView2Transform,\n",
")\n",
"from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform\n",
"from lightly.utils.scheduler import cosine_schedule"
]
},
Expand Down Expand Up @@ -126,11 +122,10 @@
"metadata": {},
"outputs": [],
"source": [
"# TiCo uses BYOL augmentations.\n",
"# We disable resizing and gaussian blur for cifar10.\n",
"transform = BYOLTransform(\n",
" view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),\n",
"transform = TiCoTransform(\n",
" view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),\n",
" view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),\n",
")\n",
"dataset = torchvision.datasets.CIFAR10(\n",
" \"datasets/cifar10\", download=True, transform=transform\n",
Expand Down
8 changes: 2 additions & 6 deletions examples/pytorch/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.tico_transform import (
TiCoTransform,
TiCoView1Transform,
TiCoView2Transform,
)
from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -58,7 +54,7 @@ def forward_momentum(self, x):
# We disable resizing and gaussian blur for cifar10.
transform = TiCoTransform(
view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
Expand Down
13 changes: 4 additions & 9 deletions examples/pytorch_lightning/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -63,11 +59,10 @@ def configure_optimizers(self):

model = TiCo()

# TiCo uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
transform = TiCoTransform(
view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
Expand Down
13 changes: 4 additions & 9 deletions examples/pytorch_lightning_distributed/tico.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
from lightly.loss.tico_loss import TiCoLoss
from lightly.models.modules.heads import TiCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.byol_transform import (
BYOLTransform,
BYOLView1Transform,
BYOLView2Transform,
)
from lightly.transforms import TiCoTransform, TiCoView1Transform, TiCoView2Transform
from lightly.utils.scheduler import cosine_schedule


Expand Down Expand Up @@ -63,11 +59,10 @@ def configure_optimizers(self):

model = TiCo()

# TiCo uses BYOL augmentations.
# We disable resizing and gaussian blur for cifar10.
transform = BYOLTransform(
view_1_transform=BYOLView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=BYOLView2Transform(input_size=32, gaussian_blur=0.0),
transform = TiCoTransform(
view_1_transform=TiCoView1Transform(input_size=32, gaussian_blur=0.0),
view_2_transform=TiCoView2Transform(input_size=32, gaussian_blur=0.0),
)
dataset = torchvision.datasets.CIFAR10(
"datasets/cifar10", download=True, transform=transform
Expand Down

0 comments on commit 3593feb

Please sign in to comment.