More backend agnostic in cutmix and mixup file #593

Merged 3 commits on Jul 24, 2023
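Both files follow the same substitution pattern: tensor math that belongs to the model moves to backend-agnostic keras_core calls, while utilities that only exist for the tf.data input pipeline are kept but imported under explicit aliases. A condensed sketch of that pattern, using illustrative values rather than lines copied from the diff:

# Sketch of the substitution pattern applied in this PR (illustrative values).
import keras_core as keras
from tensorflow import data as tf_data  # tf.data utilities stay, under an explicit alias

AUTO = tf_data.AUTOTUNE                            # was: tf.data.AUTOTUNE
cut_rat = keras.ops.sqrt(1.0 - 0.3)                # was: tf.math.sqrt(...)
cut_w = keras.backend.cast(32 * cut_rat, "int32")  # was: tf.cast(..., tf.int32)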
examples/keras_io/tensorflow/vision/cutmix.py (68 changes: 37 additions & 31 deletions)
@@ -3,7 +3,7 @@
Author: [Sayan Nath](https://twitter.com/sayannath2350)
Converted to Keras Core By: [Piyush Thakur](https://github.com/cosmo3769)
Date created: 2021/06/08
Last modified: 2023/07/18
Last modified: 2023/07/24
Description: Data augmentation with CutMix for image classification on CIFAR-10.
Accelerator: GPU
"""
@@ -48,10 +48,16 @@

import numpy as np
import pandas as pd
import keras_core as keras
import matplotlib.pyplot as plt
import tensorflow as tf

from keras_core import layers
import keras_core as keras

# TF imports related to tf.data preprocessing
from tensorflow import clip_by_value
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma

keras.utils.set_random_seed(42)

@@ -88,7 +94,7 @@
## Define hyperparameters
"""

AUTO = tf.data.AUTOTUNE
AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 32
IMG_SIZE = 32

@@ -98,9 +104,9 @@


def preprocess_image(image, label):
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
image = tf.image.convert_image_dtype(image, tf.float32) / 255.0
label = tf.cast(label, tf.float32)
image = tf_image.resize(image, (IMG_SIZE, IMG_SIZE))
image = tf_image.convert_image_dtype(image, "float32") / 255.0
label = keras.backend.cast(label, dtype="float32")
return image, label


@@ -109,19 +115,19 @@ def preprocess_image(image, label):
"""

train_ds_one = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
tf_data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(1024)
.map(preprocess_image, num_parallel_calls=AUTO)
)
train_ds_two = (
tf.data.Dataset.from_tensor_slices((x_train, y_train))
tf_data.Dataset.from_tensor_slices((x_train, y_train))
.shuffle(1024)
.map(preprocess_image, num_parallel_calls=AUTO)
)

train_ds_simple = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds_simple = tf_data.Dataset.from_tensor_slices((x_train, y_train))

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))

train_ds_simple = (
train_ds_simple.map(preprocess_image, num_parallel_calls=AUTO)
@@ -130,7 +136,7 @@ def preprocess_image(image, label):
)

# Combine two shuffled datasets from the same training data.
train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

test_ds = (
test_ds.map(preprocess_image, num_parallel_calls=AUTO)
@@ -146,28 +152,29 @@ def preprocess_image(image, label):


def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


@tf.function
def get_box(lambda_value):
cut_rat = tf.math.sqrt(1.0 - lambda_value)
cut_rat = keras.ops.sqrt(1.0 - lambda_value)

cut_w = IMG_SIZE * cut_rat # rw
cut_w = tf.cast(cut_w, tf.int32)
cut_w = keras.backend.cast(cut_w, "int32")

cut_h = IMG_SIZE * cut_rat # rh
cut_h = tf.cast(cut_h, tf.int32)
cut_h = keras.backend.cast(cut_h, "int32")

cut_x = tf.random.uniform((1,), minval=0, maxval=IMG_SIZE, dtype=tf.int32) # rx
cut_y = tf.random.uniform((1,), minval=0, maxval=IMG_SIZE, dtype=tf.int32) # ry
cut_x = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE) # rx
cut_x = keras.backend.cast(cut_x, "int32")
cut_y = keras.random.uniform((1,), minval=0, maxval=IMG_SIZE) # ry
cut_y = keras.backend.cast(cut_y, "int32")

boundaryx1 = tf.clip_by_value(cut_x[0] - cut_w // 2, 0, IMG_SIZE)
boundaryy1 = tf.clip_by_value(cut_y[0] - cut_h // 2, 0, IMG_SIZE)
bbx2 = tf.clip_by_value(cut_x[0] + cut_w // 2, 0, IMG_SIZE)
bby2 = tf.clip_by_value(cut_y[0] + cut_h // 2, 0, IMG_SIZE)
boundaryx1 = clip_by_value(cut_x[0] - cut_w // 2, 0, IMG_SIZE)
boundaryy1 = clip_by_value(cut_y[0] - cut_h // 2, 0, IMG_SIZE)
bbx2 = clip_by_value(cut_x[0] + cut_w // 2, 0, IMG_SIZE)
bby2 = clip_by_value(cut_y[0] + cut_h // 2, 0, IMG_SIZE)

target_h = bby2 - boundaryy1
if target_h == 0:
@@ -180,7 +187,6 @@ def get_box(lambda_value):
return boundaryx1, boundaryy1, target_h, target_w


@tf.function
def cutmix(train_ds_one, train_ds_two):
(image1, label1), (image2, label2) = train_ds_one, train_ds_two

@@ -197,19 +203,19 @@ def cutmix(train_ds_one, train_ds_two):
boundaryx1, boundaryy1, target_h, target_w = get_box(lambda_value)

# Get a patch from the second image (`image2`)
crop2 = tf.image.crop_to_bounding_box(
crop2 = tf_image.crop_to_bounding_box(
image2, boundaryy1, boundaryx1, target_h, target_w
)
# Pad the `image2` patch (`crop2`) with the same offset
image2 = tf.image.pad_to_bounding_box(
image2 = tf_image.pad_to_bounding_box(
crop2, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
)
# Get a patch from the first image (`image1`)
crop1 = tf.image.crop_to_bounding_box(
crop1 = tf_image.crop_to_bounding_box(
image1, boundaryy1, boundaryx1, target_h, target_w
)
# Pad the `image1` patch (`crop1`) with the same offset
img1 = tf.image.pad_to_bounding_box(
img1 = tf_image.pad_to_bounding_box(
crop1, boundaryy1, boundaryx1, IMG_SIZE, IMG_SIZE
)

@@ -221,7 +227,7 @@ def cutmix(train_ds_one, train_ds_two):

# Adjust Lambda in accordance with the pixel ratio
lambda_value = 1 - (target_w * target_h) / (IMG_SIZE * IMG_SIZE)
lambda_value = tf.cast(lambda_value, tf.float32)
lambda_value = keras.backend.cast(lambda_value, "float32")

# Combine the labels of both images
label = lambda_value * label1 + (1 - lambda_value) * label2
@@ -371,7 +377,7 @@ def training_model():

In this example, we trained our model for 15 epochs.
In our experiment, the model with CutMix achieves a better accuracy on the CIFAR-10 dataset
(76.92% in our experiment) compared to the model that doesn't use the augmentation (72.23%).
(77.34% in our experiment) compared to the model that doesn't use the augmentation (66.90%).
You may notice it takes less time to train the model with the CutMix augmentation.

You can experiment further with the CutMix technique by following the
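For reference, the step that consumes the updated cutmix function sits in a collapsed region of the diff. A sketch of the usual mapping pattern from this example, with train_ds_cmu as a placeholder name (these are not the merged lines):

# Sketch: apply CutMix by mapping over the zipped dataset defined above.
train_ds_cmu = (
    train_ds.shuffle(1024)
    .map(cutmix, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
)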
examples/keras_io/tensorflow/vision/mixup.py (37 changes: 21 additions & 16 deletions)
@@ -2,7 +2,7 @@
Title: MixUp augmentation for image classification
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/03/06
Last modified: 2021/03/06
Last modified: 2023/07/24
Description: Data augmentation using the mixup technique for image classification.
Accelerator: GPU
"""
@@ -37,10 +37,15 @@
"""

import numpy as np
import tensorflow as tf
import keras_core as keras
import matplotlib.pyplot as plt

from keras_core import layers
import keras_core as keras

# TF imports related to tf.data preprocessing
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow.random import gamma as tf_random_gamma

"""
## Prepare the dataset
@@ -53,17 +58,17 @@

x_train = x_train.astype("float32") / 255.0
x_train = np.reshape(x_train, (-1, 28, 28, 1))
y_train = tf.one_hot(y_train, 10)
y_train = keras.ops.one_hot(y_train, 10)

x_test = x_test.astype("float32") / 255.0
x_test = np.reshape(x_test, (-1, 28, 28, 1))
y_test = tf.one_hot(y_test, 10)
y_test = keras.ops.one_hot(y_test, 10)

"""
## Define hyperparameters
"""

AUTO = tf.data.AUTOTUNE
AUTO = tf_data.AUTOTUNE
BATCH_SIZE = 64
EPOCHS = 10

@@ -77,22 +82,22 @@
new_x_train, new_y_train = x_train[val_samples:], y_train[val_samples:]

train_ds_one = (
tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
)
train_ds_two = (
tf.data.Dataset.from_tensor_slices((new_x_train, new_y_train))
tf_data.Dataset.from_tensor_slices((new_x_train, new_y_train))
.shuffle(BATCH_SIZE * 100)
.batch(BATCH_SIZE)
)
# Because we will be mixing up the images and their corresponding labels, we will be
# combining two shuffled datasets from the same training data.
train_ds = tf.data.Dataset.zip((train_ds_one, train_ds_two))
train_ds = tf_data.Dataset.zip((train_ds_one, train_ds_two))

val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)
val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val)).batch(BATCH_SIZE)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)

"""
## Define the mixup technique function
@@ -105,21 +110,21 @@


def sample_beta_distribution(size, concentration_0=0.2, concentration_1=0.2):
gamma_1_sample = tf.random.gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf.random.gamma(shape=[size], alpha=concentration_0)
gamma_1_sample = tf_random_gamma(shape=[size], alpha=concentration_1)
gamma_2_sample = tf_random_gamma(shape=[size], alpha=concentration_0)
return gamma_1_sample / (gamma_1_sample + gamma_2_sample)


def mix_up(ds_one, ds_two, alpha=0.2):
# Unpack two datasets
images_one, labels_one = ds_one
images_two, labels_two = ds_two
batch_size = tf.shape(images_one)[0]
batch_size = keras.backend.shape(images_one)[0]

# Sample lambda and reshape it to do the mixup
l = sample_beta_distribution(batch_size, alpha, alpha)
x_l = tf.reshape(l, (batch_size, 1, 1, 1))
y_l = tf.reshape(l, (batch_size, 1))
x_l = keras.ops.reshape(l, (batch_size, 1, 1, 1))
y_l = keras.ops.reshape(l, (batch_size, 1))

# Perform mixup on both images and labels by combining a pair of images/labels
# (one from each dataset) into one image/label
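The lines that perform the actual blend fall just below the visible window. Conventionally, mixup forms a convex combination of the two batches weighted by the sampled lambda; a sketch of that step using the variable names above (not the merged lines):

    # Sketch: convex blend of the two batches, broadcast over images and labels.
    images = images_one * x_l + images_two * (1 - x_l)
    labels = labels_one * y_l + labels_two * (1 - y_l)
    return (images, labels)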