From 03727f3ba4fbed889e4668db80fc4ad580e78dcc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 19 Aug 2021 16:32:58 +0200 Subject: [PATCH] add CLI train (#1) * add CLI train * update req. * kaggle link Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- README.md | 10 ++-- kaggle_covid/cli_main.py | 52 +++++++++++++++++++ ...COVID-detection-with-Lightning-Flash.ipynb | 23 +------- requirements.txt | 4 +- 4 files changed, 63 insertions(+), 26 deletions(-) create mode 100644 kaggle_covid/cli_main.py diff --git a/README.md b/README.md index 08c212d..5f453ee 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Kaggle: COVID Detection +# Kaggle: [COVID-19 Detection](https://www.kaggle.com/c/siim-covid19-detection) [![CI complete testing](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/ci_testing.yml) [![Code formatting](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml/badge.svg?branch=main&event=push)](https://github.com/Borda/kaggle_COVID-detection/actions/workflows/code-format.yml) @@ -27,10 +27,14 @@ A simple way how to use this basic functions: ! pip install https://github.com/Borda/kaggle_COVID-detection/archive/main.zip ``` +### see local notebook + +- [COVID19 detection with Flash ⚡](notebooks/COVID-detection-with-Lightning-Flash.ipynb) + ### run notebooks in Kaggle -- [COVID199 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash) -- [COVID199 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions) +- [COVID19 detection with Flash ⚡](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash) +- [COVID19 detection - predictions](https://www.kaggle.com/jirkaborovec/covid-detection-with-lightning-flash-predictions) ### some results diff --git a/kaggle_covid/cli_main.py b/kaggle_covid/cli_main.py new file mode 100644 index 0000000..189b815 --- /dev/null +++ b/kaggle_covid/cli_main.py @@ -0,0 +1,52 @@ +import os + +import fire +import flash +import torch +from flash.image import ObjectDetectionData, ObjectDetector + + +def main( + path_dataset: str, + image_size: int = 512, + head: str = "efficientdet", + backbone: str = "tf_d0", + learn_rate: float = 1.5e-5, + batch_size: int = 12, + num_epochs: int = 30 +) -> None: + # 1. Create the DataModule + dm = ObjectDetectionData.from_coco( + train_folder=os.path.join(path_dataset, "images", 'train'), + train_ann_file=os.path.join(path_dataset, "covid_train.json"), + val_split=0.1, + batch_size=batch_size, + image_size=image_size, + ) + + # 2. Build the task + model = ObjectDetector( + head=head, + backbone=backbone, + optimizer=torch.optim.AdamW, + learning_rate=learn_rate, + num_classes=dm.num_classes, + image_size=image_size, + ) + + # 3. Create the trainer and finetune the model + trainer = flash.Trainer( + max_epochs=num_epochs, + gpus=torch.cuda.device_count(), + precision=16, + accumulate_grad_batches=24, + val_check_interval=0.5, + ) + trainer.finetune(model, datamodule=dm, strategy="freeze_unfreeze") + + # 3. Save the model! + trainer.save_checkpoint("object_detection_model.pt") + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/notebooks/COVID-detection-with-Lightning-Flash.ipynb b/notebooks/COVID-detection-with-Lightning-Flash.ipynb index a01d684..578925e 100644 --- a/notebooks/COVID-detection-with-Lightning-Flash.ipynb +++ b/notebooks/COVID-detection-with-Lightning-Flash.ipynb @@ -49,13 +49,7 @@ }, "outputs": [], "source": [ - "# ! pip install -qU \"numpy>=1.20\" --no-binary numpy --no-build-isolation\n", - "! pip install -q python-gdcm\n", - "# ! pip install -q pylibjpeg-libjpeg pylibjpeg-openjpeg\n", - "# ! pip install -qU \"pylibjpeg==1.2\" --no-binary :all:\n", - "! pip install -qU pydicom opencv-python-headless # \"torchvision==0.8\" \"torch==1.7\"\n", - "! pip install -q https://github.com/airctic/icevision/archive/refs/heads/master.zip\n", - "! pip install -q kaggle_COVID_detection-*.whl\n", + "! pip install -qU https://github.com/Borda/kaggle_COVID-detection/archive/main.zip\n", "! pip list | grep torch\n", "! pip list | grep lightning\n", "! pip list | grep dicom\n", @@ -1632,21 +1626,6 @@ "## Training with Flash" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "! rm -rf lightning-flash\n", - "! pip uninstall -y lightning-flash\n", - "! git clone https://github.com/PyTorchLightning/lightning-flash.git\n", - "! cd lightning-flash && git checkout feature/icevision && pip install -q .[image]\n", - "# ! pip install -q https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/feature/icevision.zip#egg=lightning-flash[image]\n", - "! pip uninstall -y fiftyone wandb\n", - "# ! pip install -q effdet" - ] - }, { "cell_type": "code", "execution_count": 27, diff --git a/requirements.txt b/requirements.txt index eddc880..bece9b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,8 @@ torch>=1.6 -lightning-flash>=0.4 +# lightning-flash[image]>=0.5 # install master till v0.5 is released +git+https://github.com/PyTorchLightning/lightning-flash.git#egg=lightning-flash[audio,image] python-gdcm pydicom opencv-python-headless pycocotools +fire