Skip to content

Commit

Permalink
Heavily improved checkpoint system
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonathunky committed Oct 30, 2023
1 parent e71d332 commit eaf00e9
Showing 1 changed file with 63 additions and 48 deletions.
111 changes: 63 additions & 48 deletions u2net_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,45 @@

bce_loss = nn.BCELoss(reduction="mean")

# Per-target training configuration, keyed by the target identifier.
# Each entry provides:
#   name         - short label used when printing progress
#   message      - banner text for the training stage
#   transform    - list of transform objects applied to the samples
#   batch_factor - integer factor associated with the batch size for this target
train_configs = {
    "plain_resized": dict(
        name="Plain Images",
        message="Learning the dataset itself...",
        transform=[Resize(512), ToTensorLab(flag=0)],
        batch_factor=1,
    ),
    "flipped_v": dict(
        name="Vertical Flips",
        message="Learning the vertical flips of dataset images...",
        transform=[Resize(512), VerticalFlip(), ToTensorLab(flag=0)],
        batch_factor=1,
    ),
    "flipped_h": dict(
        name="Horizontal Flips",
        message="Learning the horizontal flips of dataset images...",
        transform=[Resize(512), HorizontalFlip(), ToTensorLab(flag=0)],
        batch_factor=1,
    ),
    "rotated_l": dict(
        name="Left Rotations",
        message="Learning the left rotations of dataset images...",
        transform=[Resize(512), Rotation(90), ToTensorLab(flag=0)],
        batch_factor=1,
    ),
    "rotated_r": dict(
        name="Right Rotations",
        message="Learning the right rotation of dataset images...",
        transform=[Resize(512), Rotation(270), ToTensorLab(flag=0)],
        batch_factor=1,
    ),
    "random_crops": dict(
        name="Random Crops",
        message="Augmenting dataset with random crops...",
        transform=[Resize(2304), RandomCrop(256), ToTensorLab(flag=0)],
        # Crops are smaller than the resized images, so more of them fit in memory.
        batch_factor=4,
    ),
}


def dice_loss(pred, target, smooth=1.0):
pred = pred.contiguous()
Expand Down Expand Up @@ -76,42 +115,42 @@ def get_args():
"-p",
"--plain_resized",
type=int,
default=1,
default=5,
help="Number of training epochs for plain_resized target.",
)
parser.add_argument(
"-vf",
"--vflipped",
type=int,
default=1,
default=2,
help="Number of training epochs for flipped_v target.",
)
parser.add_argument(
"-hf",
"--hflipped",
type=int,
default=1,
default=2,
help="Number of training epochs for flipped_h target.",
)
parser.add_argument(
"-left",
"--rotated_l",
type=int,
default=1,
default=2,
help="Number of training epochs for rotated_l target.",
)
parser.add_argument(
"-right",
"--rotated_r",
type=int,
default=1,
default=2,
help="Number of training epochs for rotated_r target.",
)
parser.add_argument(
"-r",
"--rand",
type=int,
default=5,
default=20,
help="Number of training epochs for random_crops target.",
)

Expand Down Expand Up @@ -162,9 +201,7 @@ def load_checkpoint(net, optimizer, filename="saved_models/checkpoint.pth.tar"):
net.load_state_dict(checkpoint["state"]["state_dict"])
optimizer.load_state_dict(checkpoint["state"]["optimizer"])
training_counts = checkpoint["state"]["training_counts"]
print(
f"Loading checkpoint '{filename}' with training counts: {training_counts}..."
)
print(f"Loading checkpoint '{filename}'...")
return training_counts
else:
print(f"No checkpoint file found at '{filename}'. Starting from scratch...")
Expand Down Expand Up @@ -222,7 +259,6 @@ def train_model(net, optimizer, scheduler, dataloader, device):
for i, data in enumerate(dataloader):
inputs = data["image"].to(device)
labels = data["label"].to(device)
gc.collect()
optimizer.zero_grad()

outputs = net(inputs)
Expand Down Expand Up @@ -263,18 +299,18 @@ def train_epochs(
epoch_loss = train_model(net, optimizer, scheduler, dataloader, device)
print(f"Loss per epoch: {epoch_loss}\n")

if epoch == 0:
print(
f"Expected performance is {time.time() - start_time:.2f} seconds per epoch.\n"
)
if sum(training_counts.values()) == 3:
elapsed_time = time.time() - start_time
minutes, seconds = divmod(elapsed_time, 60)
print(f"Expected performance is {minutes:.1f} minutes seconds per epoch.\n")

# Increment the corresponding training count
training_counts[key] += 1

# Saves the model every SAVE_FRQ epochs or after the last one
if (epoch + 1) % SAVE_FRQ == 0 or epoch + 1 == len(epochs):
# in ONNX format! ^_^ UwU
save_model_as_onnx(net, device, epoch + 1)
save_model_as_onnx(net, device, sum(training_counts.values()))

# Saves checkpoint every check_frq epochs or during the last one
if (epoch + 1) % CHECK_FRQ == 0 or epoch + 1 == len(epochs):
Expand Down Expand Up @@ -305,8 +341,8 @@ def main():

targets = {
"plain_resized": args.plain_resized,
"flipped_v": args.vflipped,
"flipped_h": args.hflipped,
"flipped_v": args.vflipped,
"rotated_l": args.rotated_l,
"rotated_r": args.rotated_r,
"random_crops": args.rand,
Expand All @@ -333,6 +369,14 @@ def main():
)

training_counts = load_checkpoint(net, optimizer)
# Guard against negative values: if the model was already trained for more epochs than the target, raise the target to the trained count:
for key in training_counts:
if targets[key] < training_counts[key]:
targets[key] = training_counts[key]
print(
f"Task: {train_configs[key]['name']:<17} Epochs done: {training_counts[key]}/{targets[key]}"
)

print("---\n")

scheduler = CosineAnnealingLR(optimizer, T_max=sum(targets.values()), eta_min=1e-6)
Expand All @@ -354,38 +398,6 @@ def create_and_train(transform, batch_size, epochs, train_type):
)

# Configuration dictionary
train_configs = {
"plain_resized": {
"message": "Learning the dataset itself...",
"transform": [Resize(512), ToTensorLab(flag=0)],
"batch_factor": 1,
},
"flipped_h": {
"message": "Learning the horizontal flips of dataset images...",
"transform": [Resize(512), HorizontalFlip(), ToTensorLab(flag=0)],
"batch_factor": 1,
},
"flipped_v": {
"message": "Learning the vertical flips of dataset images...",
"transform": [Resize(512), VerticalFlip(), ToTensorLab(flag=0)],
"batch_factor": 1,
},
"rotated_l": {
"message": "Learning the left rotations of dataset images...",
"transform": [Resize(512), Rotation(90), ToTensorLab(flag=0)],
"batch_factor": 1,
},
"rotated_r": {
"message": "Learning the right rotation of dataset images...",
"transform": [Resize(512), Rotation(270), ToTensorLab(flag=0)],
"batch_factor": 1,
},
"random_crops": {
"message": "Augmenting dataset with random crops...",
"transform": [Resize(2304), RandomCrop(256), ToTensorLab(flag=0)],
"batch_factor": 4, # because they are smaller => we can fit more in memory
},
}

# Training loop
for train_type, config in train_configs.items():
Expand All @@ -399,6 +411,9 @@ def create_and_train(transform, batch_size, epochs, train_type):
)
training_counts[train_type] = targets[train_type]

if sum(difference.values()) < 1:
print("Nothing left to do!")


if __name__ == "__main__":
main()

0 comments on commit eaf00e9

Please sign in to comment.