Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CLI train #1

Merged
merged 6 commits into from
Aug 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Kaggle: COVID Detection
# Kaggle: [COVID-19 Detection](https://www.kaggle.com/c/siim-covid19-detection)

[![CI complete testing](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml)
[![Code formatting](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml)
Expand Down Expand Up @@ -27,10 +27,14 @@ A simple way how to use this basic functions:
! pip install https://github.com/Borda/kaggle_COVID-detection/archive/main.zip
```

### see local notebook

- [COVID19 detection with Flash ⚡](notebooks/COVID-detection-with-Lightning-Flash.ipynb)

### run notebooks in Kaggle

- [COVID199 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash)
- [COVID199 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions)
- [COVID19 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash)
- [COVID19 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions)

### some results

Expand Down
52 changes: 52 additions & 0 deletions kaggle_covid/cli_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import fire
import flash
import torch
from flash.image import ObjectDetectionData, ObjectDetector


def main(
path_dataset: str,
image_size: int = 512,
head: str = "efficientdet",
backbone: str = "tf_d0",
learn_rate: float = 1.5e-5,
batch_size: int = 12,
num_epochs: int = 30
) -> None:
# 1. Create the DataModule
dm = ObjectDetectionData.from_coco(
train_folder=os.path.join(path_dataset, "images", 'train'),
train_ann_file=os.path.join(path_dataset, "covid_train.json"),
val_split=0.1,
batch_size=batch_size,
image_size=image_size,
)

# 2. Build the task
model = ObjectDetector(
head=head,
backbone=backbone,
optimizer=torch.optim.AdamW,
learning_rate=learn_rate,
num_classes=dm.num_classes,
image_size=image_size,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
max_epochs=num_epochs,
gpus=torch.cuda.device_count(),
precision=16,
accumulate_grad_batches=24,
val_check_interval=0.5,
)
trainer.finetune(model, datamodule=dm, strategy="freeze_unfreeze")

# 3. Save the model!
trainer.save_checkpoint("object_detection_model.pt")


if __name__ == '__main__':
fire.Fire(main)
23 changes: 1 addition & 22 deletions notebooks/COVID-detection-with-Lightning-Flash.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@
},
"outputs": [],
"source": [
"# ! pip install -qU \"numpy>=1.20\" --no-binary numpy --no-build-isolation\n",
"! pip install -q python-gdcm\n",
"# ! pip install -q pylibjpeg-libjpeg pylibjpeg-openjpeg\n",
"# ! pip install -qU \"pylibjpeg==1.2\" --no-binary :all:\n",
"! pip install -qU pydicom opencv-python-headless # \"torchvision==0.8\" \"torch==1.7\"\n",
"! pip install -q https://github.com/airctic/icevision/archive/refs/heads/master.zip\n",
"! pip install -q kaggle_COVID_detection-*.whl\n",
"! pip install -qU https://github.com/Borda/kaggle_COVID-detection/archive/main.zip\n",
"! pip list | grep torch\n",
"! pip list | grep lightning\n",
"! pip list | grep dicom\n",
Expand Down Expand Up @@ -1632,21 +1626,6 @@
"## Training with Flash"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! rm -rf lightning-flash\n",
"! pip uninstall -y lightning-flash\n",
"! git clone https://github.com/PyTorchLightning/lightning-flash.git\n",
"! cd lightning-flash && git checkout feature/icevision && pip install -q .[image]\n",
"# ! pip install -q https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/feature/icevision.zip#egg=lightning-flash[image]\n",
"! pip uninstall -y fiftyone wandb\n",
"# ! pip install -q effdet"
]
},
{
"cell_type": "code",
"execution_count": 27,
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
torch>=1.6
lightning-flash>=0.4
# lightning-flash[image]>=0.5 # install master till v0.5 is released
git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[audio,image]
python-gdcm
pydicom
opencv-python-headless
pycocotools
fire