Skip to content

Commit

Permalink
Image prediction using trained model (#2392)
Browse files Browse the repository at this point in the history
* Image prediction using trained model
* Inference on custom images
* Updated the PR following the PEP8 guidelines and made the requested changes
---------
Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
  • Loading branch information
HemanthSai7 authored Jun 9, 2023
1 parent a5376f7 commit a58279c
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions beginner_source/transfer_learning_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory

cudnn.benchmark = True
Expand Down Expand Up @@ -337,6 +338,47 @@ def visualize_model(model, num_images=6):
plt.ioff()
plt.show()


######################################################################
# Inference on custom images
# --------------------------
#
# Use the trained model to make predictions on custom images and visualize
# the predicted class labels along with the images.
#

def visualize_model_predictions(model,img_path):
was_training = model.training
model.eval()

img = Image.open(img_path)
img = data_transforms['val'](img)
img = img.unsqueeze(0)
img = img.to(device)

with torch.no_grad():
outputs = model(img)
_, preds = torch.max(outputs, 1)

ax = plt.subplot(2,2,1)
ax.axis('off')
ax.set_title(f'Predicted: {class_names[preds[0]]}')
imshow(img.cpu().data[0])

model.train(mode=was_training)

######################################################################
#

visualize_model_predictions(
model_conv,
img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
)

plt.ioff()
plt.show()


######################################################################
# Further Learning
# -----------------
Expand Down

0 comments on commit a58279c

Please sign in to comment.