A Vision Transformer model that classifies bird images into 325 species. The model was trained on this dataset.
Training code can be found here. The model uses the deit_tiny_patch16_224
architecture and expects 224x224 input images.
More information about DeiT can be found here.
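The fixed 224x224 input size follows from how the model tokenizes an image: it is cut into non-overlapping 16x16 patches (the `patch16` in the model name), and the learned position embeddings are sized for exactly that number of patches. The arithmetic below is a small sketch; the function name `vit_sequence_length` is hypothetical, and it assumes a single class token and no distillation token, as in the non-distilled DeiT-Tiny (distilled DeiT variants add one more token).

```python
# Token-count arithmetic for a ViT/DeiT with square patches on a square input.
# Assumption: one extra [CLS] token, no distillation token.

def vit_sequence_length(image_size: int = 224, patch_size: int = 16, extra_tokens: int = 1) -> int:
    """Return the number of tokens the transformer encoder processes."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size  # 224 // 16 = 14
    num_patches = patches_per_side ** 2          # 14 * 14 = 196
    return num_patches + extra_tokens            # 196 + 1 class token = 197


print(vit_sequence_length())  # 197
```

Feeding a different resolution would change the patch count and no longer match the position embeddings, which is one reason the images are resized and cropped to 224x224 before inference.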
Dependencies:
Install dependencies with pip install torch timm
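Before running the snippets below, it can help to confirm that the packages are importable. This is a minimal sketch using only the standard library; the helper name `missing_dependencies` is hypothetical. Note that import names differ from pip names in one case: the `PIL` module is installed by the `pillow` package.

```python
# Report which modules used in this README cannot be imported.
import importlib.util

REQUIRED_MODULES = ("torch", "timm", "torchvision", "PIL")


def missing_dependencies(modules=REQUIRED_MODULES):
    """Return the subset of `modules` that cannot be found on the import path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_dependencies()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All dependencies found.")
```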
import torch
import timm # Required dependency for loading model
model = torch.hub.load("SharanSMenon/birds-325-model", "birds_325_deit_tiny_patch16_224")
model.eval()  # switch to inference mode before making predictions
import json
from urllib.request import urlopen

# classes.json holds the species names; index i corresponds to output logit i
URL = "https://raw.githubusercontent.com/SharanSMenon/birds-325-model/main/classes.json"
with urlopen(URL) as response:
    classes = json.loads(response.read())
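A slightly more defensive variant wraps the download in a function with a timeout and checks the parsed type. This is a sketch, not part of the original repository: `load_classes` is a hypothetical helper, and the check that the file is a JSON array is an assumption based on the README indexing `classes` with an integer.

```python
import json
from urllib.request import urlopen

CLASSES_URL = "https://raw.githubusercontent.com/SharanSMenon/birds-325-model/main/classes.json"


def load_classes(url: str = CLASSES_URL, timeout: float = 10.0) -> list:
    """Download and parse the class-name list (index i -> name for logit i)."""
    with urlopen(url, timeout=timeout) as response:  # connection is closed on exit
        classes = json.loads(response.read())
    # The prediction code indexes `classes` with an int, so expect a JSON array.
    if not isinstance(classes, list):
        raise TypeError(f"expected a JSON array, got {type(classes).__name__}")
    return classes
```

Usage: `classes = load_classes()`. Raising a clear error here is cheaper to debug than a `KeyError` or `IndexError` later in the prediction step.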
PIL and torchvision are required dependencies for inference (pip install pillow torchvision).
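from PIL import Image
import torchvision.transforms as T

# Preprocessing: resize the shorter side to 256, center-crop to 224x224,
# convert to a [0, 1] tensor, then normalize with ImageNet mean/std.
test_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])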
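To see what `test_transform` does to a concrete image, here is a pure-Python walk-through of the size and value arithmetic. It is a sketch: the helper names are hypothetical, and rounding the longer side down after `T.Resize(256)` is an assumption about torchvision's internals (the documented behaviour is only that the shorter side becomes 256 and the aspect ratio is kept).

```python
# Pure-Python illustration of the Resize -> CenterCrop -> ToTensor -> Normalize steps.

MEAN = (0.485, 0.456, 0.406)  # per-channel means passed to T.Normalize
STD = (0.229, 0.224, 0.225)   # per-channel standard deviations passed to T.Normalize


def resized_size(width: int, height: int, shorter: int = 256) -> tuple:
    """Return (width, height) after scaling the shorter side to `shorter`.

    Assumption: the longer side is rounded down, keeping the aspect ratio.
    """
    if width <= height:
        return shorter, int(shorter * height / width)
    return int(shorter * width / height), shorter


def normalize_pixel(rgb: tuple) -> tuple:
    """Scale 0-255 values to 0-1 (ToTensor), then apply (x - mean) / std per channel."""
    return tuple((value / 255 - m) / s for value, m, s in zip(rgb, MEAN, STD))


print(resized_size(1000, 750))  # (341, 256); CenterCrop(224) then keeps the central 224x224
print(normalize_pixel((255, 255, 255)))
```

A pure white pixel therefore maps to values a little above 2 in every channel, and pure black to small negative values; this is the value range the model saw during training.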
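image = Image.open("painted-bunting.jpg").convert("RGB")  # ensure 3 color channels
transformed = test_transform(image)  # tensor of shape (3, 224, 224)
batch = transformed.unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)
with torch.no_grad():
    output = model(batch)  # one score per species: shape (1, 325)
prediction = classes[output.argmax(dim=1).item()]
print(prediction)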
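`argmax` returns only the single most likely species. The model's output is presumably raw logits (timm classification heads return unnormalized scores), so a softmax turns them into probabilities and lets you list the top few candidates. The sketch below does this in plain Python; `top_k_predictions` is a hypothetical helper. In PyTorch the same result can be obtained with `torch.softmax` followed by `torch.topk`.

```python
import math


def top_k_predictions(logits, class_names, k=5):
    """Return the k most likely (class_name, probability) pairs from raw logits."""
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first, and keep the first k.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(class_names[i], probs[i]) for i in ranked]


# Usage with the model output above (a 1 x 325 tensor of logits):
# for name, p in top_k_predictions(output[0].tolist(), classes):
#     print(f"{name}: {p:.1%}")
```

Showing several candidates with their probabilities makes it easier to spot uncertain predictions, for example between visually similar species.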