Tutorial on Knowledge Distillation with Keras

This tutorial aims to introduce the concept of knowledge distillation and its application in medical image processing.

What is knowledge distillation?

Knowledge distillation refers to any type of training that transfers knowledge from a cumbersome model to a simpler or more compact one. It was first introduced by Bucilua et al. [1] as a model compression method for compressing a large ensemble into a single compact model. The method was shown to lower computation time and memory requirements while maintaining performance. Later, the concept was extended to the "Teacher-Student" setting, in which a student model is trained to mimic the knowledge acquired by a teacher model so that it produces the same predictions. This is accomplished by using the class probabilities produced by the softmax layer of the teacher model as target values and training the student model by minimizing a cross-entropy objective.

However, small values among the teacher's class probabilities contribute very little to the cross-entropy function. Although these small values carry valuable information about the similarity structure of the data, they have limited impact on the training of the student network. To overcome this drawback, Ba and Caruana [2] proposed using the teacher's logit values as targets instead of its final prediction probabilities. Logits are the pre-softmax activations and therefore preserve the (logarithmic) relationships between the prediction probabilities. Rather than using the logits directly, Hinton et al. [3] introduced a temperature parameter that softens the teacher's outputs and gives the smaller probabilities a better representation in the target values. In addition to optimizing the student model on the ground-truth labels, their objective function penalizes the student based on the softened teacher output as follows:

$$\mathcal{L}_{student} = (1 - \alpha)\, H\!\left(y, \sigma(z_s)\right) + \alpha\, T^2\, H\!\left(\sigma\!\left(\tfrac{z_t}{T}\right), \sigma\!\left(\tfrac{z_s}{T}\right)\right)$$

where $\sigma(z_s)$ is the prediction of the student model, $y$ is the ground-truth label, and $H$ is the cross-entropy function. $z_t$ and $z_s$ are the logits produced by the teacher and the student model, respectively, $T$ is the temperature parameter, $\sigma$ is the softmax function, and $\alpha$ is the parameter that balances the contribution of the hard ground-truth labels against the teacher's soft labels.

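To make the effect of the temperature concrete, the short NumPy sketch below (illustrative only; the logit values are made up) shows how dividing the logits by T before the softmax softens the output distribution:

import numpy as np

def softmax(z, T=1.0):
    # Temperature-scaled softmax: higher T spreads probability mass
    # over the smaller logits, exposing the teacher's "dark knowledge"
    e = np.exp((z - np.max(z)) / T)
    return e / e.sum()

logits = np.array([8.0, 2.0, 0.5])   # hypothetical teacher logits for 3 classes

print(softmax(logits, T=1))  # approx [0.997, 0.002, 0.001]  (hard, peaked targets)
print(softmax(logits, T=5))  # approx [0.66, 0.20, 0.15]     (softened targets)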

Implementation:

To set up a "Teacher-Student" framework, we create two separate CNNs, one for the teacher and one for the student. A CNN with any arbitrary architecture can serve as either model; a minimal sketch illustrating this with two different architectures follows the list below. This is in contrast to transfer learning, where both models must share the same architecture so that the weights of the pre-trained model can be copied to the new model. The following steps are required to implement knowledge distillation:

  1. Train a teacher model

  2. Create a student model and train it with the knowledge distillation (KD) loss

    2.1. Prepare the data

    2.2. Create the student model

    2.3. Define the KD loss function

    2.4. Compile and train the student-teacher model

    2.5. Evaluate the student model on the test dataset

  3. Train a student from scratch as a baseline model for comparison
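As an illustrative sketch of the first point (the layer counts, filter sizes, and names below are made up and unrelated to the U-Nets used later), the only constraint is that the two networks produce outputs of comparable shape; their internal architectures can differ completely:

from keras.layers import Input, Conv3D
from keras.models import Model

# Hypothetical teacher: deeper and wider
t_in = Input(shape=(None, None, None, 1))
t = Conv3D(32, 3, padding='same', activation='relu')(t_in)
t = Conv3D(64, 3, padding='same', activation='relu')(t)
t = Conv3D(64, 3, padding='same', activation='relu')(t)
t_out = Conv3D(1, 1, activation='sigmoid', name='teacher_out')(t)
teacher = Model(t_in, t_out)

# Hypothetical student: shallow and narrow, with a different layout
s_in = Input(shape=(None, None, None, 1))
s = Conv3D(8, 5, padding='same', activation='relu')(s_in)
s_out = Conv3D(1, 1, activation='sigmoid', name='student_out')(s)
student = Model(s_in, s_out)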

Train a teacher model:

Initially, we create a U-Net as the teacher model; any CNN architecture could be used instead. To learn more about the implementation of the U-Net model, check out the following link: https://github.com/JeroenBertels/deepvoxnet2/tree/main/deepvoxnet2/keras/models

import keras.backend as K
from keras.optimizers import SGD, Adam

# create_unet_like_model and the `loss` metrics are assumed to come from the
# deepvoxnet2 repository linked above; initial_learning_rate and optimizer
# are hyperparameters chosen by the user
teacher_model = create_unet_like_model(number_input_features,
                                       subsample_factors_per_pathway,
                                       kernel_sizes_per_pathway,
                                       number_features_per_pathway,
                                       output_size,
                                       padding,
                                       upsampling,
                                       l1_regularization,
                                       l2_regularization)

objective_function = K.binary_crossentropy
metrics_ = [loss.binary_crossentropy, loss.binary_dice]

teacher_model.compile(loss=objective_function,
                      optimizer=SGD(lr=initial_learning_rate, momentum=0.9, nesterov=True) if optimizer == "SGD" else Adam(lr=initial_learning_rate),
                      metrics=metrics_)
teacher_model.fit(training_x, training_y)

Preparing the dataset:

The data first need to be divided into training and testing sets and then normalized. In this tutorial the data are 3D images, so the intensities are normalized per volume:

import numpy as np

def intensityNormalization(I):
    # Min-max normalization of a 3D volume to the [0, 1] range
    I_min = np.min(I, axis=(0, 1, 2))
    I_max = np.max(I, axis=(0, 1, 2))
    return (I - I_min) / (I_max - I_min)
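The split itself is not shown in the tutorial; a minimal sketch using scikit-learn's train_test_split (an assumption, since any splitting strategy works; `images` and `labels` are placeholder arrays of 3D volumes and segmentations) could be:

import numpy as np
from sklearn.model_selection import train_test_split

# normalize each volume, then hold out 20% of the cases for testing
images_norm = np.stack([intensityNormalization(img) for img in images])

training_x, x_test, training_y, y_test = train_test_split(
    images_norm, labels, test_size=0.2, random_state=42)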

Create the student model:

student_model = create_unet_like_model(number_input_features,
                                       subsample_factors_per_pathway,
                                       kernel_sizes_per_pathway,
                                       number_features_per_pathway,
                                       output_size,
                                       padding,
                                       upsampling,
                                       l1_regularization,
                                       l2_regularization)

Define the KD Loss:

The distillation loss is a weighted combination of (i) the soft Dice loss between the student's prediction ('output_student') and the ground truth, and (ii) the cross-entropy loss between the softened outputs of the student and the teacher ('soft_preds'). The loss function for the distillation framework is defined as a dictionary, keyed by the output names of the combined model:

loss_function = {'output_student': soft_dice, 'soft_preds': KD}
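The soft_dice loss referenced above is not shown here; a commonly used formulation (a sketch, not necessarily the exact implementation used in this repository) is:

import keras.backend as K

def soft_dice(y_true, y_pred, smooth=1e-5):
    # Soft Dice loss: 1 minus the Dice coefficient computed on the soft predictions
    axes = tuple(range(1, K.ndim(y_pred)))          # sum over all but the batch axis
    intersection = K.sum(y_true * y_pred, axis=axes)
    denominator = K.sum(y_true, axis=axes) + K.sum(y_pred, axis=axes)
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice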

The KD term of the dictionary is defined as:

import keras.backend as K

alpha = 0.1   # example value for α, the weight of the distillation term
# `temperature` (T in the loss above) is set in the next section, before training

def KD(y_true, y_pred):
    # y_pred is the channel-wise concatenation of the softened teacher and
    # student outputs (teacher first, then student; see the Concatenate layer below)
    d = K.int_shape(y_pred)[-1]
    soft_teacher = y_pred[..., :d // 2]
    soft_student = y_pred[..., d // 2:]
    # the T^2 factor keeps the gradients of the soft term comparable in
    # magnitude to those of the hard term (Hinton et al. [3])
    kd_loss = (alpha * temperature ** 2) * K.binary_crossentropy(soft_teacher, soft_student)
    return kd_loss

Define, compile and train the student-teacher model:

from keras.layers import Activation, Lambda, Concatenate
from keras.models import Model

temperature = 5

# the teacher is already trained; freeze its weights so only the student is updated
teacher_model.trainable = False

# the student's pre-softmax layer is assumed to be named 'activation_out_student'
# and its final output layer 'output_student' (to match the loss dictionary)
logit_student = student_model.get_layer('activation_out_student').output
hard_out_student = Activation('softmax')(logit_student)

soft_out_student = Lambda(lambda x: x / temperature)(logit_student)
soft_prob_student = Activation('softmax')(soft_out_student)

logit_teacher = teacher_model.get_layer('activation_out_teacher').output
soft_out_teacher = Lambda(lambda x: x / temperature)(logit_teacher)
soft_prob_teacher = Activation('softmax')(soft_out_teacher)

# the layer name must match the 'soft_preds' key used in loss_function and metrics_dict
concat_out_layer_student_teacher = Concatenate(axis=-1, name='soft_preds')([soft_prob_teacher, soft_prob_student])

# the teacher branch is connected to the teacher's own input, so the combined
# model takes both inputs; the same image batch is fed to both during training
student_teacher_model = Model([student_model.input, teacher_model.input],
                              [student_model.output, concat_out_layer_student_teacher])

metrics_ = [loss.binary_crossentropy, loss.binary_dice]
metrics_dict = {'output_student': metrics_, 'soft_preds': []}

student_teacher_model.compile(loss=loss_function,
                              optimizer=SGD(lr=initial_learning_rate, momentum=0.9, nesterov=True) if optimizer == "SGD" else Adam(lr=initial_learning_rate),
                              metrics=metrics_dict)
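Training the combined model is then a regular fit call. Since the KD loss above never uses its y_true argument, a dummy target of the right shape (zeros here) can be passed for 'soft_preds'. This is a sketch under that assumption; number_of_epochs and batch_size are placeholders:

import numpy as np

# dummy target for 'soft_preds'; its values are ignored by the KD loss,
# only the shape matters (same spatial size as the labels, twice the channels)
dummy_soft_targets = np.zeros(training_y.shape[:-1] + (2 * training_y.shape[-1],))

student_teacher_model.fit([training_x, training_x],   # same images for student and teacher inputs
                          {'output_student': training_y,
                           'soft_preds': dummy_soft_targets},
                          epochs=number_of_epochs,
                          batch_size=batch_size)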

Evaluate the student model on the test dataset:

student_model.evaluate(x_test, y_test)  # the distilled ("enriched") student model

Train a student from scratch as a baseline model for comparison:

student_model = create_unet_like_model(number_input_features,
                                       subsample_factors_per_pathway,
                                       kernel_sizes_per_pathway,
                                       number_features_per_pathway,
                                       output_size,
                                       padding,
                                       upsampling,
                                       l1_regularization,
                                       l2_regularization)
objective_function = K.binary_crossentropy
metrics_=[loss.binary_crossentropy, loss.binary_dice]

student_model.compile(loss = objective_function,
                      optimizer = SGD(lr=initial_learning_rate, momentum=0.9, nesterov=True) if optimizer == "SGD" else Adam(lr=initial_learning_rate),
                      metrics = metrics_)
student_model.fit(training_x, training_y)

student_model.evaluate(x_test, y_test)  # baseline student model trained from scratch, without distillation

Acknowledgements

This project has received funding from the European Union’s Horizon 2020 research and innovation programme under the Marie Skłodowska-Curie grant agreement No 764458 (https://www.hybrid2020.eu/).

References:

[1] Bucilua, C., Caruana, R., Niculescu-Mizil, A.: Model compression. In: Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 535–541 (2006)

[2] Ba, J., Caruana, R.: Do deep nets really need to be deep? In: Advances in Neural Information Processing Systems, pp. 2654–2662 (2014)

[3] Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
