
This project is an image classifier that builds on a pre-trained VGG16 or ResNet18 model to classify images into the classes present in the dataset. When running the script, the user can specify the model architecture, the learning rate, the number of hidden units in the classifier, the number of training epochs, and whether to train on a GPU.


Dependencies

  • torch
  • torchvision
  • argparse (part of the Python standard library)
  • matplotlib
  • numpy
  • PIL (installed as the Pillow package)

Usage

The script can be run using the following command:

python train.py dataset_folder
  • dataset_folder: root directory containing the train, valid, and test subdirectories described under Data Preparation
  • save_dir (optional): directory in which to save the checkpoint
  • arch (optional): model architecture, either "vgg16" or "resnet18"
  • learning_rate (optional): learning rate for the optimizer
  • hidden_units (optional): number of hidden units in the classifier
  • epochs (optional): number of training epochs
  • gpu (optional): flag to train on a GPU
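
The options above map naturally onto an argparse parser. The sketch below is an assumption about how such a command line could be declared: the function name parse_args, the "--" flag spellings, and every default value are hypothetical, since only the option names appear in this README.

```python
import argparse


def parse_args(argv=None):
    """Sketch of a CLI matching the listed options; flag spellings and defaults are assumed."""
    parser = argparse.ArgumentParser(description="Train an image classifier on a pre-trained backbone.")
    parser.add_argument("dataset_folder", help="root directory with train/, valid/ and test/ subfolders")
    parser.add_argument("--save_dir", default=".", help="directory to save checkpoint.pth in")
    parser.add_argument("--arch", choices=["vgg16", "resnet18"], default="vgg16", help="model architecture")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="optimizer learning rate")
    parser.add_argument("--hidden_units", type=int, default=512, help="hidden units in the classifier")
    parser.add_argument("--epochs", type=int, default=5, help="number of training epochs")
    # store_true makes --gpu an on/off switch that defaults to False.
    parser.add_argument("--gpu", action="store_true", help="train on a GPU")
    return parser.parse_args(argv)
```

With this parser, an invocation such as python train.py dataset_folder --arch resnet18 --epochs 3 --gpu would be accepted, and an unsupported value like --arch alexnet would be rejected by argparse before training starts.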

Data Preparation

The script expects the dataset to be structured as follows:

dataset_folder
|__ train
|   |__ class1
|   |   |__ image1.jpg
|   |   |__ image2.jpg
|   |   |__ ...
|   |__ class2
|   |   |__ image1.jpg
|   |   |__ image2.jpg
|   |   |__ ...
|   |__ ...
|__ valid
|   |__ class1
|   |   |__ image1.jpg
|   |   |__ image2.jpg
|   |   |__ ...
|   |__ class2
|   |   |__ image1.jpg
|   |   |__ image2.jpg
|   |   |__ ...
|   |__ ...
|__ test
    |__ class1
    |   |__ image1.jpg
    |   |__ image2.jpg
    |   |__ ...
    |__ class2
    |   |__ image1.jpg
    |   |__ image2.jpg
    |   |__ ...
    |__ ...
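
Before training, it can help to confirm that a dataset matches the layout above. The following is a minimal, standard-library sketch; the function name summarize_dataset and the accepted image extensions are assumptions, and the script itself may not perform this check. It returns the image count per class for each split and rejects splits whose class folders differ from train/.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed set of accepted extensions


def summarize_dataset(root):
    """Hypothetical check: return {split: {class_name: image_count}} for the expected layout."""
    root = Path(root)
    summary = {}
    for split in ("train", "valid", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Each subdirectory is one class; count only files with an image extension.
        summary[split] = {
            class_dir.name: sum(
                1 for f in class_dir.iterdir() if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
            )
            for class_dir in sorted(split_dir.iterdir())
            if class_dir.is_dir()
        }
    # Every split should contain the same set of class folders.
    train_classes = set(summary["train"])
    for split in ("valid", "test"):
        if set(summary[split]) != train_classes:
            raise ValueError(f"class folders in {split}/ do not match train/")
    return summary
```

Summing the values of each inner dictionary gives a per-split image total, which can be compared with the counts the script prints.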

Output

The script prints the number of images in each dataset split along with the class names. It also saves the trained model as checkpoint.pth in the specified save_dir.
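
A checkpoint is most useful when it stores enough metadata to rebuild the model before loading its weights. The sketch below shows one possible layout; the helper name save_checkpoint and the dictionary keys ("arch", "hidden_units", "class_to_idx", "state_dict") are hypothetical and may differ from what train.py actually writes.

```python
from pathlib import Path

import torch
from torch import nn


def save_checkpoint(model: nn.Module, save_dir, arch: str, hidden_units: int, class_to_idx: dict) -> Path:
    """Hypothetical checkpoint layout: metadata plus the model's weights."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    path = save_dir / "checkpoint.pth"
    torch.save(
        {
            "arch": arch,                   # which backbone to rebuild
            "hidden_units": hidden_units,   # size of the classifier's hidden layer
            "class_to_idx": class_to_idx,   # e.g. torchvision ImageFolder(...).class_to_idx
            "state_dict": model.state_dict(),
        },
        path,
    )
    return path
```

To reload, read the file with torch.load(path, map_location="cpu"), rebuild the same architecture from "arch" and "hidden_units", and pass "state_dict" to load_state_dict. Saving the state_dict rather than the whole model object keeps the file independent of the training script's class definitions.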

In addition, it displays the model's loss and accuracy after each epoch.
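
Per-epoch loss and accuracy are typically computed by running the model over a held-out split without gradients. The sketch below assumes the metrics come from a DataLoader such as one built on the valid split; the README does not say which split is used, and the helper name evaluate is hypothetical.

```python
import torch
from torch import nn


def evaluate(model: nn.Module, loader, criterion, device="cpu"):
    """Return (average loss, accuracy) of model over loader. The model must already be on device."""
    model.eval()  # disable dropout for evaluation
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # criterion averages over the batch, so weight by batch size for an exact overall mean.
            total_loss += criterion(outputs, labels).item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    model.train()  # restore training mode for the next epoch
    return total_loss / seen, correct / seen
```

Inside an epoch loop this might be called as val_loss, val_acc = evaluate(model, valid_loader, nn.NLLLoss(), device) and the two values printed. Taking argmax works whether the model outputs raw logits or log-probabilities, since both preserve the ranking of classes.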