This repo is my attempt to explore Google's Vision Transformer (ViT) model, using PyTorch.
The original paper, *An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale*, can be found [here](https://arxiv.org/abs/2010.11929).
- A Python >= 3.8 environment is recommended.
- To install the dependencies, run `pip install -r requirements.txt`.
- `modules.py`: Contains the implementation of the different modules/sub-networks used in the ViT model (a minimal sketch of one such module appears after this list).
- `main.py`: Contains the code to test the model against the official implementation using the timm module.
- `test.py`: Contains the code to test the model with a sample image using the COCO weights, after running `main.py`.
- `inspect.py`: Contains the code to inspect the model, layer by layer, after running `main.py`.
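For reference, the heart of a ViT implementation is a patch-embedding layer followed by standard transformer encoder blocks. Below is a minimal, self-contained PyTorch sketch of a patch-embedding module; the class name, argument names, and default hyperparameters are illustrative and need not match what `modules.py` actually contains.

```python
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    """Split an image into fixed-size patches and project each to an embedding.

    Illustrative sketch only; modules.py may name and structure this differently.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to flattening each patch
        # and applying a shared linear projection.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                  # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        return x


if __name__ == "__main__":
    dummy = torch.randn(1, 3, 224, 224)
    print(PatchEmbedding()(dummy).shape)  # torch.Size([1, 196, 768])
```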
- Run `python main.py` to load the official model's weights into the one we built. This will save the model weights in the `./data` folder (a rough sketch of this step appears after this list).
- Run `python test.py` to test the model with a sample image. This will print the top 5 predictions. I'm using a cat image; feel free to change it (see the inference sketch below).
- Run `python inspect.py` to inspect the model, layer by layer. This will print the output of each layer (see the hook-based sketch below).
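As a rough illustration of the `main.py` step, the sketch below loads the pretrained ViT-B/16 checkpoint through timm and saves its state dict under `./data`. The weight file name and the commented-out comparison against the custom model are assumptions; the real script may map parameter names explicitly.

```python
import os

import torch
import timm

# from modules import ViT  # hypothetical import of the model built in this repo

# Load the reference implementation with pretrained weights from timm.
official = timm.create_model("vit_base_patch16_224", pretrained=True)
official.eval()

# Hypothetical comparison against the custom model; in practice the state-dict
# keys usually need an explicit mapping before load_state_dict succeeds.
# custom = ViT()
# custom.load_state_dict(official.state_dict())
# custom.eval()
# x = torch.randn(1, 3, 224, 224)
# with torch.no_grad():
#     print(torch.allclose(official(x), custom(x), atol=1e-5))

# Save the weights so test.py and inspect.py can reuse them.
os.makedirs("data", exist_ok=True)
torch.save(official.state_dict(), "./data/vit_weights.pth")
```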
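A hedged sketch of the sample-image test: it reuses the weight file assumed above, resolves the checkpoint's preprocessing config through timm, and prints the top-5 class indices with their probabilities. The image path `cat.jpg` is a placeholder, and `test.py` may handle labels differently.

```python
import torch
import timm
from PIL import Image
from timm.data import resolve_data_config, create_transform

model = timm.create_model("vit_base_patch16_224", pretrained=False)
# Hypothetical weight file produced by the main.py step above.
model.load_state_dict(torch.load("./data/vit_weights.pth"))
model.eval()

# Build the preprocessing pipeline that matches this checkpoint's config.
transform = create_transform(**resolve_data_config({}, model=model))

# "cat.jpg" is a placeholder; point it at any image you like.
img = transform(Image.open("cat.jpg").convert("RGB")).unsqueeze(0)

with torch.no_grad():
    probs = model(img).softmax(dim=-1)

top5 = probs.topk(5)
for prob, idx in zip(top5.values[0], top5.indices[0]):
    print(f"class {idx.item():4d}  p={prob.item():.4f}")
```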
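One way to get a layer-by-layer view is to register a forward hook on every leaf module, as sketched below; `inspect.py` may go about it differently, and the weight file name is again an assumption.

```python
import torch
import timm

model = timm.create_model("vit_base_patch16_224", pretrained=False)
# Hypothetical weight file produced by the main.py step above.
model.load_state_dict(torch.load("./data/vit_weights.pth"))
model.eval()


def report(name):
    # Print the layer's name, output shape, and output mean when it fires.
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            print(f"{name:45s} {tuple(output.shape)} mean={output.mean().item():+.4f}")
    return hook


# Leaf modules (those without children) correspond to individual layers.
for name, module in model.named_modules():
    if len(list(module.children())) == 0:
        module.register_forward_hook(report(name))

with torch.no_grad():
    model(torch.randn(1, 3, 224, 224))
```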
- Implement the ViT model
- Test the model against the official implementation
- Test the model with a sample image
- Inspect the model, layer by layer
- Train the model on a custom dataset
- Add a demo to test the model on a custom image