Warning: this project is under development.
To install the library, go to the project folder and install it with pip in editable mode:
```shell
cd resnet-pytorch
pip install -e .
```
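To confirm that the editable install is visible to the interpreter you will actually use (for example, a notebook kernel), a quick standard-library check can help. This is a generic sketch, not part of the library; it assumes the package is importable as `resnet`, as in the snippets below, and `is_installed` is a hypothetical helper name.

```python
import importlib.util


def is_installed(package_name: str) -> bool:
    """Return True if the current interpreter can locate `package_name`."""
    # find_spec returns None when no top-level module/package with that name exists.
    return importlib.util.find_spec(package_name) is not None


if __name__ == "__main__":
    # After `pip install -e .` this should print True for the `resnet` package.
    print(is_installed("resnet"))
```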
This library provides three main actions for the ResNet CNN: `train`, `test`, and `inference`. The `demo` folder contains an example of how to use it, with a notebook ready to run in Colab. The snippets below explain the code in the `demo` folder.
The following code trains a ResNet. Training requires a `CONFIG_PARAMS` constant: a dictionary with the training parameters, such as `batch_size`, `categories`, `optimizer`, `learning_rate`, and so on. The `train` function receives this dictionary and returns the path where the weights were saved as a `.pt` file.
```python
# Import the resnet library installed above
import resnet

# Define the config params for the whole process
CONFIG_PARAMS = {
    'batch_size': 16,
    'categories': 10,
    'optimizer': 'sgd',
    'learning_rate': 0.001,
    'loss': 'cross-entropy',
    'epochs': 5,
    'model_name': 'resnet',
    'path': 'runs',
    'dataset_name': 'cifar10',
}

# Train the ResNet model; returns the path to the saved .pt weights
weightsPath = resnet.train(params=CONFIG_PARAMS)
```
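Because a typo in `CONFIG_PARAMS` may only surface after training has started, it can be worth checking the dictionary first. The sketch below is a hypothetical helper, not part of the resnet library: it assumes the keys from the example above are the ones `train` needs and only checks their presence, their types, and that the size/count values are positive. The library may accept or require other keys.

```python
# Hypothetical helper, not part of the resnet library. It assumes the keys
# used in the README example are the required ones.
REQUIRED_KEYS = {
    "batch_size": int,
    "categories": int,
    "optimizer": str,
    "learning_rate": float,
    "loss": str,
    "epochs": int,
    "model_name": str,
    "path": str,
    "dataset_name": str,
}


def check_config(params: dict) -> list:
    """Return a list of problems found in `params` (empty if none)."""
    problems = []
    for key, expected_type in REQUIRED_KEYS.items():
        if key not in params:
            problems.append(f"missing key: {key}")
        elif not isinstance(params[key], expected_type):
            problems.append(
                f"{key} should be {expected_type.__name__}, "
                f"got {type(params[key]).__name__}"
            )
    # Sizes and counts must be positive integers.
    for key in ("batch_size", "categories", "epochs"):
        value = params.get(key)
        if isinstance(value, int) and value <= 0:
            problems.append(f"{key} must be positive")
    return problems
```

For the configuration above, `check_config(CONFIG_PARAMS)` should return an empty list, so it can be called just before `resnet.train`.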
The `test` function receives the `params` parameter and the weights path returned by `train`, and returns the accuracy computed for the trained model.
```python
# Import the resnet library installed above
import resnet

# Test the ResNet model with the weights produced by train()
accuracy = resnet.test(params=CONFIG_PARAMS,
                       weightsPath=weightsPath)
```
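Accuracy is presumably the usual classification metric: the share of test images whose predicted category equals the true label. Below is a minimal sketch of that computation in plain Python, assuming integer class indices. It is illustrative only; whether `resnet.test` returns a fraction or a percentage is not stated here, so check the library before comparing numbers.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the ground-truth labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy of an empty set")
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)


if __name__ == "__main__":
    # 3 of 4 predictions match, so this prints 0.75.
    print(accuracy([3, 5, 1, 0], [3, 5, 2, 0]))
```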
The `inference` function receives an image, the model, and the device, and returns the predicted category of the image. The following example uses PIL to load the image, plus some utilities from `resnet.utils` to load the model and get the device.
```python
# Import the resnet library installed above
import resnet
from PIL import Image

# Constant with the path of an image to run inference on
IMG_PATH = '../gallery/cat.jpeg'

# Get the main device for inference (`gpu` by default)
DEVICE = resnet.utils.getDevice()

# Load the trained model and the image
model = resnet.utils.loadModel(weightsPath=weightsPath,
                               params=CONFIG_PARAMS,
                               device=DEVICE)
image = Image.open(IMG_PATH)

# Perform inference (preprocessing and prediction)
results = resnet.inference(image, model, DEVICE)
```
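The exact return value of `resnet.inference` is not documented here. If a model produces one raw score (logit) per category, turning those scores into a label typically means a softmax followed by an argmax. The sketch below does that in plain Python, assuming a CIFAR-10 model (`'categories': 10`) that uses torchvision's CIFAR-10 class order; `CIFAR10_CLASSES`, `softmax`, and `top_class` are hypothetical names, not part of the library.

```python
import math

# Class order used by torchvision's CIFAR-10 dataset. Assumption: the model
# was trained with 'dataset_name': 'cifar10' using this same index order.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits):
    """Convert raw scores into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_class(logits, class_names=CIFAR10_CLASSES):
    """Return (class name, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]
```

For example, `top_class([0.1, 0.2, 0.0, 3.0, 0.5, 1.0, 0.0, 0.2, 0.1, 0.0])` picks index 3, that is `"cat"`, together with its softmax probability.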