Commit

works on cifar10 with mixup
fpaissan committed Nov 27, 2023
1 parent a3b4a26 commit 70c5ff9
Showing 2 changed files with 15 additions and 15 deletions.
recipes/image_classification/cifar10.sh: 6 changes (1 addition, 5 deletions)
@@ -1,6 +1,2 @@
 accelerate launch train.py cfg/phinet.py -b 64 --dataset torch/cifar10 --num-classes 10 \
-    --model phinet -epochs 100 --amp \
-    -lr 0.005 --weight-decay 0.02 \
-    --experiment_name cifar10 \
-    --alpha 3 --beta 0.75 --t_zero 6 --num_layers 7
-    # --hflip 0.5 # --aa rand-m3-mstd0.55 --mixup 0.1 --bce-loss \
+    --hflip 0.5 --aa rand-m3-mstd0.55 --mixup 0.1 --experiment_name cifar10 --bce-loss
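After this change the script enables horizontal flips, RandAugment (rand-m3-mstd0.55), mixup (0.1) and a BCE loss. With mixup, batch targets are no longer integer class indices but blended one-hot vectors, which is what the train.py changes below account for. A minimal illustration of the resulting target shape (a sketch assuming a timm-style mixup pipeline; the mixing coefficient and labels are made up for the example):

import torch
import torch.nn.functional as F

# Illustration only, not part of the recipe: mixup blends the one-hot targets of a
# batch with those of a shuffled copy, so the target becomes a (B, num_classes)
# float tensor instead of a (B,) tensor of class indices.
num_classes = 10
lam = 0.9                          # mixing coefficient, fixed here for the example
y = torch.tensor([3, 7])           # original labels
y_shuffled = torch.tensor([5, 1])  # labels of the samples mixed in
soft_target = lam * F.one_hot(y, num_classes).float() + (1.0 - lam) * F.one_hot(y_shuffled, num_classes).float()
print(soft_target.shape)           # torch.Size([2, 10]) -> target.ndim == 2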
recipes/image_classification/train.py: 24 changes (14 additions, 10 deletions)
@@ -99,11 +99,11 @@ def compute_loss(self, pred, batch):
         return self.criterion(pred[0], pred[1])
 
     def configure_optimizers(self):
-        opt = torch.optim.Adam(self.modules.parameters(), lr=1e-2, weight_decay=0.0005)
-        sched = torch.optim.lr_scheduler.CosineAnnealingLR(
-            opt, T_max=5000, eta_min=1e-7
-        )
-        return opt, sched
+        opt = torch.optim.Adam(self.modules.parameters(), lr=3e-4, weight_decay=0.0005)
+        # sched = torch.optim.lr_scheduler.CosineAnnealingLR(
+        #     opt, T_max=5000, eta_min=1e-7
+        # )
+        return opt
 
 
 def top_k_accuracy( k=1):
@@ -121,9 +121,13 @@ def top_k_accuracy( k=1):
         Top-K accuracy.
     """
     def acc(pred, batch):
+        if pred[1].ndim == 2:
+            target = pred[1].argmax(1)
+        else:
+            target = pred[1]
         _, indices = torch.topk(pred[0], k, dim=1)
-        correct = torch.sum(indices == pred[1].view(-1, 1))
-        accuracy = correct.item() / pred[1].size(0)
+        correct = torch.sum(indices == target.view(-1, 1))
+        accuracy = correct.item() / target.size(0)
 
         return torch.Tensor([accuracy]).to(pred[0].device)
 
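For reference, the updated acc closure above now handles both plain integer targets and the 2-D soft targets produced when mixup/BCE is enabled. A standalone check of that logic (a sketch using only PyTorch, no micromind imports; the tensors are made up for the example):

import torch
import torch.nn.functional as F

def top_k_accuracy(k=1):
    # Same logic as the updated metric above, reproduced standalone.
    def acc(pred, batch):
        if pred[1].ndim == 2:                 # soft / one-hot targets (mixup, BCE)
            target = pred[1].argmax(1)
        else:                                 # plain class indices
            target = pred[1]
        _, indices = torch.topk(pred[0], k, dim=1)
        correct = torch.sum(indices == target.view(-1, 1))
        return torch.Tensor([correct.item() / target.size(0)]).to(pred[0].device)
    return acc

logits = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])
hard = torch.tensor([1, 0])                   # integer labels
soft = F.one_hot(hard, 3).float()             # 2-D targets, as mixup would produce
print(top_k_accuracy(k=1)((logits, hard), None))  # tensor([1.])
print(top_k_accuracy(k=1)((logits, soft), None))  # tensor([1.])

With hard labels the behaviour is unchanged; with one-hot or mixed targets the argmax recovers the dominant class before the top-k comparison.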
@@ -144,13 +148,13 @@ def acc(pred, batch):

     mind = ImageClassification(hparams=hparams)
 
-    top1 = mm.Metric("top1_acc", top_k_accuracy(k=1))
-    top5 = mm.Metric("top5_acc", top_k_accuracy(k=5))
+    top1 = mm.Metric("top1_acc", top_k_accuracy(k=1), eval_only=True)
+    top5 = mm.Metric("top5_acc", top_k_accuracy(k=5), eval_only=True)
 
     mind.train(
         epochs=100,
         datasets={"train": train_loader, "val": val_loader},
-        metrics=[top1, top5],
+        metrics=[top5, top1],
         checkpointer=checkpointer,
         debug=hparams.debug,
     )