
ResNet implemented in pytorch ready to train, test and inference


pQbas/resnet-pytorch


Warning

Under development

resnet-pytorch

ResNet CNN implemented in PyTorch, ready for training, testing, and inference.

How to install?

Go to the project folder and install the package with pip in editable mode:

cd resnet-pytorch
pip install -e .
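To confirm that the editable install worked, one option is to check that a `resnet` package can be located on the import path. The helper `is_installed` below is hypothetical, not part of this library; it is a minimal sketch that uses only the standard library and does not actually import the package.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if a top-level package can be located on the import path.

    Hypothetical helper for illustration: it asks the import system for a
    module spec instead of importing the package itself.
    """
    return importlib.util.find_spec(package_name) is not None


if __name__ == "__main__":
    # After `pip install -e .` this is expected to report True.
    print("resnet installed:", is_installed("resnet"))
```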

How to use it?

This library provides three main actions for the ResNet CNN: train, test and inference. The demo folder contains a notebook, ready to run in Colab, that shows how to use them. The snippets below walk through the code in the demo folder.

Train action

The following code trains a ResNet. Training requires a CONFIG_PARAMS constant: a dictionary with the training parameters, such as batch size, number of categories, optimizer, learning rate, loss, number of epochs, output path and dataset name. The train function receives this dictionary and returns the path where the weights were saved as a .pt file.

# Import resnet library previously installed
import resnet

# Define the config params used by all actions (train, test and inference)
CONFIG_PARAMS = {
    'batch_size'    : 16,
    'categories'    : 10,
    'optimizer'     : 'sgd',
    'learning_rate' : 0.001,
    'loss'          : 'cross-entropy',
    'epochs'        : 5,
    'model_name'    : 'resnet',
    'path'          : 'runs',
    'dataset_name'  : 'cifar10',
}

# Train the resnet model
weightsPath = resnet.train(params = CONFIG_PARAMS)
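Because the train action is driven entirely by CONFIG_PARAMS, a misspelled key or a non-positive value may only surface partway through a run. Below is a minimal sketch of a pre-flight check, assuming the keys in the example above are the ones the library expects; `REQUIRED_KEYS` and `validate_config` are hypothetical names, not part of the `resnet` API.

```python
# Hypothetical pre-flight check for a CONFIG_PARAMS-style dictionary.
# Assumption: these keys and types mirror the example config; the library may accept more.
REQUIRED_KEYS = {
    "batch_size": int,
    "categories": int,
    "optimizer": str,
    "learning_rate": float,
    "loss": str,
    "epochs": int,
    "model_name": str,
    "path": str,
    "dataset_name": str,
}


def validate_config(params: dict) -> list:
    """Return a list of problems found; an empty list means the config looks usable."""
    problems = []
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in params:
            problems.append(f"missing key: {key!r}")
        elif not isinstance(params[key], expected_type):
            problems.append(
                f"{key!r} should be {expected_type.__name__}, "
                f"got {type(params[key]).__name__}"
            )
    # Numeric hyperparameters only make sense when strictly positive.
    for key in ("batch_size", "categories", "learning_rate", "epochs"):
        value = params.get(key)
        if isinstance(value, (int, float)) and value <= 0:
            problems.append(f"{key!r} must be positive, got {value}")
    return problems
```

Calling `validate_config(CONFIG_PARAMS)` before `resnet.train(...)`, and raising if the returned list is non-empty, keeps configuration mistakes from wasting a training run.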

Test action

This action computes the accuracy of the trained model. The test function receives the same params dictionary used for training, plus the weights path returned by the train action.

# Import resnet library previously installed
import resnet

# Test the ResNet model (CONFIG_PARAMS and weightsPath come from the train action)
accuracy = resnet.test(params      = CONFIG_PARAMS, 
                       weightsPath = weightsPath)
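Accuracy is the fraction of test images whose predicted category matches the ground-truth label. The sketch below shows that computation on plain Python lists; it illustrates the metric only and is not the library's own implementation. Whether `resnet.test` reports a fraction or a percentage is not stated here.

```python
from typing import Sequence


def accuracy(predictions: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of positions where the predicted class index equals the true label."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == y for p, y in zip(predictions, labels))
    return correct / len(labels)


# Example: 3 of 4 predictions are correct.
print(accuracy([3, 5, 1, 0], [3, 5, 2, 0]))  # 0.75
```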

Inference action

The inference function receives an image, the model and the device, and returns the predicted category of the image. The following example uses PIL to load the image, plus library utilities for loading the model and selecting the device.

# Import resnet library previously installed
import resnet
from PIL import Image

# Constant with the path of the image used for inference
IMG_PATH = '../gallery/cat.jpeg'

# Get the device used for inference (`gpu` by default)
DEVICE = resnet.utils.getDevice()

# Load the trained model and the image
model = resnet.utils.loadModel(weightsPath = weightsPath, 
                               params      = CONFIG_PARAMS, 
                               device      = DEVICE)
image = Image.open(IMG_PATH)

# Perform inference (preprocessing and prediction)
results = resnet.inference(image, model, DEVICE)
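Since the example trains on `cifar10`, a predicted class index can be turned into a readable category name using the standard CIFAR-10 class order. The sketch below assumes the model yields one score per class in the conventional torchvision ordering; `CIFAR10_CLASSES` and `top_class` are illustrative names, and the exact structure of the `results` value returned by `resnet.inference` is not documented here.

```python
# Standard CIFAR-10 class order (as used by torchvision.datasets.CIFAR10).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def top_class(scores, class_names=CIFAR10_CLASSES):
    """Return (index, name) of the highest-scoring class.

    Assumption: `scores` is a flat sequence with one score (logit or
    probability) per class, in the same order as `class_names`.
    """
    if len(scores) != len(class_names):
        raise ValueError(f"expected {len(class_names)} scores, got {len(scores)}")
    best = max(range(len(scores)), key=lambda i: scores[i])
    return best, class_names[best]


# Illustrative scores where index 3 ("cat") is highest.
print(top_class([0.1, 0.0, 0.2, 2.5, 0.3, 1.1, 0.0, 0.4, 0.1, 0.2]))  # (3, 'cat')
```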
