diff --git a/beginner_source/transfer_learning_tutorial.py b/beginner_source/transfer_learning_tutorial.py index f08312522c..7a2b053763 100644 --- a/beginner_source/transfer_learning_tutorial.py +++ b/beginner_source/transfer_learning_tutorial.py @@ -44,6 +44,7 @@ import matplotlib.pyplot as plt import time import os +from PIL import Image from tempfile import TemporaryDirectory cudnn.benchmark = True @@ -337,6 +338,47 @@ def visualize_model(model, num_images=6): plt.ioff() plt.show() + +###################################################################### +# Inference on custom images +# -------------------------- +# +# Use the trained model to make predictions on custom images and visualize +# the predicted class labels along with the images. +# + +def visualize_model_predictions(model,img_path): + was_training = model.training + model.eval() + + img = Image.open(img_path) + img = data_transforms['val'](img) + img = img.unsqueeze(0) + img = img.to(device) + + with torch.no_grad(): + outputs = model(img) + _, preds = torch.max(outputs, 1) + + ax = plt.subplot(2,2,1) + ax.axis('off') + ax.set_title(f'Predicted: {class_names[preds[0]]}') + imshow(img.cpu().data[0]) + + model.train(mode=was_training) + +###################################################################### +# + +visualize_model_predictions( + model_conv, + img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg' +) + +plt.ioff() +plt.show() + + ###################################################################### # Further Learning # -----------------