GitHub repository for improving classification performance by exploiting segment information. See also the project web page.
To install the project, simply clone the repository and get the necessary dependencies:
git clone https://github.com/MarcoParola/improve_classifier_via_segment.git
cd improve_classifier_via_segment
Create a virtual environment (conda works as well) and install the dependencies listed in requirements.txt:
python -m venv env
. env/bin/activate
python -m pip install -r requirements.txt
mkdir data
Then download the oral COCO dataset (both the images and the JSON file) from TODO-put-link. Copy them into the data folder and unzip the file oral1.zip.
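After unzipping, a quick sanity check on the annotation file can catch a bad download early. The sketch below assumes the archive sits at data/oral1.zip and that the annotations follow the standard COCO layout with top-level images, annotations, and categories keys. The file name dataset.json is a placeholder, not a name taken from this repository.

```python
import json
import zipfile
from pathlib import Path


def extract_archive(zip_path, dest_dir):
    """Extract a zip archive into dest_dir and return the extracted member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        return archive.namelist()


def summarize_coco(json_path):
    """Count entries in the three standard COCO sections of an annotation file."""
    with open(json_path, encoding="utf-8") as f:
        coco = json.load(f)
    return {key: len(coco.get(key, [])) for key in ("images", "annotations", "categories")}


if __name__ == "__main__":
    # Both paths are assumptions: adjust them to the files you actually downloaded.
    archive_path = Path("data/oral1.zip")
    annotation_path = Path("data/dataset.json")  # hypothetical file name
    if archive_path.exists():
        print(extract_archive(archive_path, "data"))
    if annotation_path.exists():
        print(summarize_coco(annotation_path))
```

If any of the three counts is zero, the JSON file is probably incomplete or not in COCO format.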
Next, create a new project on Weights & Biases named improve_classifier_via_segment. Set the entity parameter in config.yaml to your own W&B entity. Then log in and paste your API key when prompted.
wandb login
Here is a quick overview of the main uses of the repo. Further information is available in the official documentation.
Classification on the whole dataset:
- Train CNN classifier on the whole dataset
- Test CNN classifier on the whole dataset
Specify the pre-trained classification model by setting model.weights. classification_mode=whole specifies that classification is performed without exploiting the segment information.
# TRAIN classifier on whole images
python train.py task=c classification_mode=whole model.weights=ConvNeXt_Small_Weights.DEFAULT
# TEST classifier on whole images
python test.py task=c classification_mode=whole checkpoint.version=123
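When several configurations are run back to back, it can help to assemble the key=value overrides programmatically. The following is a minimal sketch, not part of the repository: build_command is a hypothetical helper that only reproduces the command-line shape shown above.

```python
import shlex
import sys


def build_command(script, **overrides):
    """Return an argv list such as [python, train.py, task=c, ...]."""
    args = [sys.executable, script]
    # Dotted keys such as model.weights are not valid Python identifiers,
    # so callers pass them through a dict: **{"model.weights": "..."}.
    args.extend(f"{key}={value}" for key, value in overrides.items())
    return args


if __name__ == "__main__":
    cmd = build_command(
        "train.py",
        task="c",
        classification_mode="whole",
        **{"model.weights": "ConvNeXt_Small_Weights.DEFAULT"},
    )
    # Print only; pass cmd to subprocess.run(cmd, check=True) to actually launch it.
    print(shlex.join(cmd))
```

Returning a list rather than a single string avoids shell quoting issues if the command is later handed to subprocess.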
Classification on the masked dataset:
- Train CNN for segmentation
- Test CNN for segmentation
- Train CNN classifier on the masked dataset
- Test CNN classifier on the masked dataset
Specify the pre-trained segmentation model by setting model_seg. classification_mode=masked specifies that classification is performed by exploiting the segment information.
The first step of this task is to train a segmentation NN that will be used to generate masks for images in the next step.
# TRAIN segmentation NN
python train.py task=s model_seg='fcn'
# TEST segmentation NN
python test.py task=s model_seg='fcn' checkpoint.version=123
After training your segmentation NN, insert the version of the model you want to use for masked classification in the __init__ method of src/data/masked_classification/dataset.py.
Specify the pre-trained classification model by setting model.weights. Specify the previously trained segmentation model used to generate the masks by setting model_seg.
# TRAIN classifier on masked images
python train.py task=c classification_mode=masked model.weights=ConvNeXt_Small_Weights.DEFAULT model_seg='fcn' sgm_type='soft'
# TEST classifier on masked images
python test.py task=c classification_mode=masked model_seg='fcn' checkpoint.version=123
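The commands above do not spell out what sgm_type='soft' does. One common reading, offered here purely as an assumption, is that a soft mask scales each pixel by the segmentation probability, while a hard mask thresholds the probability first. The sketch uses plain nested lists for a single-channel image; apply_mask, the 'hard' value, and the 0.5 threshold are illustrative and not taken from the repository.

```python
def apply_mask(image, mask_probs, sgm_type="soft", threshold=0.5):
    """Mask a 2D image (list of rows) with per-pixel probabilities in [0, 1].

    ASSUMPTION: 'soft' multiplies by the probability, 'hard' by a 0/1 mask.
    The repository's actual behaviour may differ.
    """
    if sgm_type not in ("soft", "hard"):
        raise ValueError(f"unknown sgm_type: {sgm_type!r}")
    masked = []
    for img_row, prob_row in zip(image, mask_probs):
        if sgm_type == "soft":
            weights = prob_row
        else:
            # Binarize the probabilities before applying them.
            weights = [1.0 if p >= threshold else 0.0 for p in prob_row]
        masked.append([pixel * w for pixel, w in zip(img_row, weights)])
    return masked


if __name__ == "__main__":
    image = [[10.0, 20.0], [30.0, 40.0]]
    probs = [[0.5, 0.0], [1.0, 0.25]]
    print(apply_mask(image, probs, "soft"))
    print(apply_mask(image, probs, "hard"))
```

Under this reading, a soft mask keeps some context around uncertain region borders, whereas a hard mask removes everything the segmentation NN scores below the threshold.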
Classification on the whole dataset exploiting saliency maps and masks:
- Train CNN classifier on the original dataset while backpropagating the saliency map error
- Test CNN classifier on the whole dataset
Specify the pre-trained classification model by setting model.weights. classification_mode=saliency specifies that classification is performed by exploiting the saliency map information.
# TRAIN classifier on whole images with saliency map information
python train.py task=c classification_mode=saliency model.weights=ConvNeXt_Small_Weights.DEFAULT
# TEST classifier on whole images with saliency map information
python test.py task=c classification_mode=saliency checkpoint.version=123
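To make the idea of backpropagating a saliency map error more concrete, the sketch below computes one plausible scalar objective: a classification loss plus a weighted penalty on the mismatch between the saliency map and the segment mask. The mean squared error, the weight parameter, and all function names are assumptions for illustration; the repository may combine the terms differently. In actual training, an autodiff framework would differentiate this quantity.

```python
import math


def cross_entropy(logits, target_index):
    """Cross-entropy of one sample from raw logits (numerically stable log-softmax)."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum - logits[target_index]


def saliency_error(saliency, mask):
    """Mean squared error between a saliency map and a mask, both flat lists."""
    if len(saliency) != len(mask):
        raise ValueError("saliency and mask must have the same length")
    return sum((s - m) ** 2 for s, m in zip(saliency, mask)) / len(mask)


def combined_loss(logits, target_index, saliency, mask, weight=1.0):
    """ASSUMED objective: classification loss plus a weighted saliency penalty."""
    return cross_entropy(logits, target_index) + weight * saliency_error(saliency, mask)


if __name__ == "__main__":
    logits = [2.0, 0.5, -1.0]
    saliency = [0.9, 0.1, 0.4, 0.0]  # flattened saliency map
    mask = [1.0, 0.0, 0.0, 0.0]      # flattened ground-truth segment mask
    print(combined_loss(logits, 0, saliency, mask, weight=0.5))
```

With weight set to 0 the objective reduces to plain classification, which corresponds to the classification_mode=whole setting.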
To inspect the logs with TensorBoard:
python -m tensorboard.main --logdir=logs