Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix issue with newer pytorch versions #6

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,24 @@ Formally, REMIND takes an input image and passes it through frozen layers of a n

## Dependencies

:warning::warning: | For unknown reasons, our code does not reproduce results in PyTorch versions newer than 1.3.1. Please follow our instructions below to ensure reproducibility.
:---: | :---

We have tested the code with the following packages and versions:
- Python 3.7.6
- PyTorch (GPU) 1.3.1
- torchvision 0.4.2
- NumPy 1.18.5
- Python 3.8.13
- PyTorch (GPU) 1.12.1
- torchvision 0.13.1
- NumPy 1.21.5
- FAISS (CPU) 1.5.2
- CUDA 10.1 (also works with CUDA 10.0)
- Scikit-Learn 0.23.1
- Scipy 1.1.0
- CUDA 10.2 (also works with CUDA 11.3)
- Scikit-Learn 1.0.2
- Scipy 1.7.3
- NVIDIA GPU


We recommend setting up a `conda` environment with these same package versions:
```
conda create -n remind_proj python=3.7
conda create -n remind_proj python=3.8
conda activate remind_proj
conda install numpy=1.18.5
conda install pytorch=1.3.1 torchvision=0.4.2 cudatoolkit=10.1 -c pytorch
conda install numpy=1.21.5
conda install pytorch=1.12.1 torchvision=0.13.1 cudatoolkit=10.2 -c pytorch
conda install faiss-cpu=1.5.2 -c pytorch
```

Expand Down
7 changes: 5 additions & 2 deletions image_classification_experiments/REMINDModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,17 @@ def __init__(self, num_classes, classifier_G='ResNet18ClassifyAfterLayer4_1',
self.classifier_G = ModelWrapper(core_model, output_layer_names=[extract_features_from], return_single=True)

# make the optimizer
trainable_params = self.get_trainable_params(self.classifier_F, start_lr)
self.optimizer = optim.SGD(trainable_params, momentum=0.9, weight_decay=weight_decay)
self.optimizer = optim.SGD(self.classifier_F.parameters(), lr= start_lr, momentum=0.9, weight_decay=weight_decay)

# setup lr decay
if lr_mode in ['step_lr_per_class']:
self.lr_scheduler_per_class = {}
self.lr_per_class = {}
for class_ix in range(0, num_classes):
self.lr_scheduler_per_class[class_ix] = optim.lr_scheduler.StepLR(self.optimizer,
step_size=lr_step_size,
gamma=lr_gamma)
self.lr_per_class[class_ix] = start_lr
else:
self.lr_scheduler_per_class = None

Expand Down Expand Up @@ -131,6 +132,7 @@ def fit_incremental_batch(self, curr_loader, latent_dict, pq, rehearsal_ixs=None
for x, y, item_ix in zip(codes, batch_labels, batch_item_ixs):
if self.lr_mode == 'step_lr_per_class' and (ongoing_class is None or ongoing_class != y):
ongoing_class = y
self.optimizer.param_groups[0]['lr'] = self.lr_per_class[int(y)]

if self.use_mixup:
# gather two batches of previous data for mixup and replay
Expand Down Expand Up @@ -256,6 +258,7 @@ def fit_incremental_batch(self, curr_loader, latent_dict, pq, rehearsal_ixs=None
# update lr scheduler
if self.lr_scheduler_per_class is not None:
self.lr_scheduler_per_class[int(y)].step()
self.lr_per_class[int(y)] = self.optimizer.param_groups[0]['lr']

def mixup_data(self, x1, y1, x2, y2, alpha=1.0):
if alpha > 0:
Expand Down
2 changes: 1 addition & 1 deletion image_classification_experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def accuracy(output, target, topk=(1,), output_has_class_ids=False):

res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size).item())
return res

Expand Down