Skip to content

Commit

Permalink
add CLI train (#1)
Browse files Browse the repository at this point in the history
* add CLI train
* update req.
* kaggle link

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Aug 19, 2021
1 parent 9b203e0 commit 03727f3
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 26 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Kaggle: COVID Detection
# Kaggle: [COVID-19 Detection](https://www.kaggle.com/c/siim-covid19-detection)

[![CI complete testing](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml)
[![Code formatting](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml)
Expand Down Expand Up @@ -27,10 +27,14 @@ A simple way how to use this basic functions:
! pip install https://github.com/Borda/kaggle_COVID-detection/archive/main.zip
```

### see local notebook

- [COVID19 detection with Flash ⚡](notebooks/COVID-detection-with-Lightning-Flash.ipynb)

### run notebooks in Kaggle

- [COVID199 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash)
- [COVID199 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions)
- [COVID19 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash)
- [COVID19 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions)

### some results

Expand Down
52 changes: 52 additions & 0 deletions kaggle_covid/cli_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import fire
import flash
import torch
from flash.image import ObjectDetectionData, ObjectDetector


def main(
path_dataset: str,
image_size: int = 512,
head: str = "efficientdet",
backbone: str = "tf_d0",
learn_rate: float = 1.5e-5,
batch_size: int = 12,
num_epochs: int = 30
) -> None:
# 1. Create the DataModule
dm = ObjectDetectionData.from_coco(
train_folder=os.path.join(path_dataset, "images", 'train'),
train_ann_file=os.path.join(path_dataset, "covid_train.json"),
val_split=0.1,
batch_size=batch_size,
image_size=image_size,
)

# 2. Build the task
model = ObjectDetector(
head=head,
backbone=backbone,
optimizer=torch.optim.AdamW,
learning_rate=learn_rate,
num_classes=dm.num_classes,
image_size=image_size,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(
max_epochs=num_epochs,
gpus=torch.cuda.device_count(),
precision=16,
accumulate_grad_batches=24,
val_check_interval=0.5,
)
trainer.finetune(model, datamodule=dm, strategy="freeze_unfreeze")

# 3. Save the model!
trainer.save_checkpoint("object_detection_model.pt")


if __name__ == '__main__':
fire.Fire(main)
23 changes: 1 addition & 22 deletions notebooks/COVID-detection-with-Lightning-Flash.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,7 @@
},
"outputs": [],
"source": [
"# ! pip install -qU \"numpy>=1.20\" --no-binary numpy --no-build-isolation\n",
"! pip install -q python-gdcm\n",
"# ! pip install -q pylibjpeg-libjpeg pylibjpeg-openjpeg\n",
"# ! pip install -qU \"pylibjpeg==1.2\" --no-binary :all:\n",
"! pip install -qU pydicom opencv-python-headless # \"torchvision==0.8\" \"torch==1.7\"\n",
"! pip install -q https://github.com/airctic/icevision/archive/refs/heads/master.zip\n",
"! pip install -q kaggle_COVID_detection-*.whl\n",
"! pip install -qU https://github.com/Borda/kaggle_COVID-detection/archive/main.zip\n",
"! pip list | grep torch\n",
"! pip list | grep lightning\n",
"! pip list | grep dicom\n",
Expand Down Expand Up @@ -1632,21 +1626,6 @@
"## Training with Flash"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! rm -rf lightning-flash\n",
"! pip uninstall -y lightning-flash\n",
"! git clone https://github.com/PyTorchLightning/lightning-flash.git\n",
"! cd lightning-flash && git checkout feature/icevision && pip install -q .[image]\n",
"# ! pip install -q https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/feature/icevision.zip#egg=lightning-flash[image]\n",
"! pip uninstall -y fiftyone wandb\n",
"# ! pip install -q effdet"
]
},
{
"cell_type": "code",
"execution_count": 27,
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
torch>=1.6
lightning-flash>=0.4
# lightning-flash[image]>=0.5 # install master till v0.5 is released
git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[audio,image]
python-gdcm
pydicom
opencv-python-headless
pycocotools
fire

0 comments on commit 03727f3

Please sign in to comment.