Support new Input object in Flash Zero (#974)
ethanwharris authored Nov 17, 2021
1 parent fca3808 commit 41d97b5
Showing 20 changed files with 65 additions and 39 deletions.
1 change: 1 addition & 0 deletions flash/audio/classification/cli.py
@@ -45,6 +45,7 @@ def audio_classification():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("audio_classification_model.pt")
1 change: 1 addition & 0 deletions flash/audio/speech_recognition/cli.py
@@ -49,6 +49,7 @@ def speech_recognition():
"trainer.max_epochs": 3,
},
finetune=False,
legacy=True,
)

cli.trainer.save_checkpoint("speech_recognition_model.pt")
37 changes: 29 additions & 8 deletions flash/core/utilities/flash_cli.py
@@ -23,8 +23,10 @@
from jsonargparse import ArgumentParser
from jsonargparse.signatures import get_class_signature_functions
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.model_helpers import is_overridden

import flash
from flash import DataModule
from flash.core.data.io.input import InputFormat
from flash.core.utilities.lightning_cli import (
class_from_function,
@@ -117,6 +119,7 @@ def __init__(
default_arguments=None,
finetune=True,
datamodule_attributes=None,
legacy: bool = False,
**kwargs: Any,
) -> None:
"""Flash's extension of the :class:`pytorch_lightning.utilities.cli.LightningCLI`
@@ -141,6 +144,7 @@ def __init__(
self.additional_datamodule_builders = additional_datamodule_builders or []
self.default_arguments = default_arguments or {}
self.finetune = finetune
self.legacy = legacy

model_class = make_args_optional(model_class, self.datamodule_attributes)
self.local_datamodule_class = datamodule_class
@@ -185,8 +189,17 @@ def add_arguments_to_parser(self, parser) -> None:
for input in inputs:
if isinstance(input, InputFormat):
input = input.value
if hasattr(self.local_datamodule_class, f"from_{input}"):
self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, f"from_{input}"))
function = f"from_{input}"
if (
(hasattr(self.local_datamodule_class, function) and self.legacy)
or (
hasattr(DataModule, function)
and is_overridden(function, self.local_datamodule_class, DataModule)
and not self.legacy
)
or (not hasattr(DataModule, function) and not self.legacy)
):
self.add_subcommand_from_function(subcommands, getattr(self.local_datamodule_class, function))

for datamodule_builder in self.additional_datamodule_builders:
self.add_subcommand_from_function(subcommands, datamodule_builder)
@@ -199,13 +212,21 @@ def add_arguments_to_parser(self, parser) -> None:
def add_subcommand_from_function(self, subcommands, function, function_name=None):
subcommand = ArgumentParser()
datamodule_function = class_from_function(drop_kwargs(function), return_type=self.local_datamodule_class)
input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls))
subcommand.add_class_arguments(datamodule_function, fail_untyped=False)
subcommand.add_class_arguments(
input_transform_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, input_transform_function),
)
if self.legacy:
input_transform_function = class_from_function(drop_kwargs(self.local_datamodule_class.input_transform_cls))
subcommand.add_class_arguments(
input_transform_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, input_transform_function),
)
else:
base_datamodule_function = class_from_function(drop_kwargs(self.local_datamodule_class))
subcommand.add_class_arguments(
base_datamodule_function,
fail_untyped=False,
skip=get_overlapping_args(datamodule_function, base_datamodule_function),
)
subcommand_name = function_name or function.__name__
subcommands.add_subcommand(subcommand_name, subcommand)
self._subcommand_builders[subcommand_name] = function
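Taken together, the two flash_cli.py hunks above gate Flash Zero on the new legacy flag in two places: which from_* subcommands get registered, and where each subcommand's extra CLI arguments come from. The registration predicate reads more easily split into cases; this sketch is a restatement for readability, not part of the commit (datamodule_cls stands in for self.local_datamodule_class):

    from flash import DataModule
    from pytorch_lightning.utilities.model_helpers import is_overridden

    def should_register(datamodule_cls, function: str, legacy: bool) -> bool:
        if legacy:
            # Legacy tasks register every from_* loader they expose.
            return hasattr(datamodule_cls, function)
        if hasattr(DataModule, function):
            # New-style tasks register a loader the base DataModule also defines
            # only when the subclass actually overrides it...
            return is_overridden(function, datamodule_cls, DataModule)
        # ...and always register loaders the base class does not define.
        return True

On the argument side, legacy tasks keep sourcing additional flags from input_transform_cls, while new-style tasks source them from the DataModule constructor, with get_overlapping_args skipping anything the loader's own signature already declares.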
8 changes: 2 additions & 6 deletions flash/graph/classification/cli.py
@@ -21,9 +21,7 @@
def from_tu_dataset(
name: str = "KKI",
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> GraphClassificationData:
"""Downloads and loads the TU Dataset."""
from flash.core.utilities.imports import _TORCH_GEOMETRIC_AVAILABLE
@@ -38,9 +36,7 @@ def from_tu_dataset(
return GraphClassificationData.from_datasets(
train_dataset=dataset,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


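With from_tu_dataset now forwarding **data_module_kwargs straight through to from_datasets, shared options such as batch_size and num_workers no longer need to be redeclared in each loader; on the new-style path they surface as CLI flags through the DataModule constructor arguments added in add_subcommand_from_function. The same refactor is applied to the pointcloud, tabular, and video loaders below. A hypothetical invocation in the style of the test added at the end of this commit (flag names assumed from the generated parser, not taken from the commit):

    from unittest import mock

    from flash.__main__ import main

    # Drive Flash Zero programmatically, the way the test suite does.
    cli_args = ["flash", "graph_classification", "from_tu_dataset", "--batch_size", "4"]
    with mock.patch("sys.argv", cli_args):
        try:
            main()
        except SystemExit:
            pass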
1 change: 1 addition & 0 deletions flash/image/classification/cli.py
@@ -64,6 +64,7 @@ def image_classification():
"trainer.max_epochs": 3,
},
datamodule_attributes={"num_classes", "multi_label"},
legacy=True,
)

cli.trainer.save_checkpoint("image_classification_model.pt")
1 change: 1 addition & 0 deletions flash/image/detection/cli.py
@@ -46,6 +46,7 @@ def object_detection():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("object_detection_model.pt")
1 change: 1 addition & 0 deletions flash/image/instance_segmentation/cli.py
@@ -57,6 +57,7 @@ def instance_segmentation():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("instance_segmentation_model.pt")
1 change: 1 addition & 0 deletions flash/image/keypoint_detection/cli.py
@@ -57,6 +57,7 @@ def keypoint_detection():
"model.num_keypoints": 1,
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("keypoint_detection_model.pt")
1 change: 1 addition & 0 deletions flash/image/segmentation/cli.py
@@ -51,6 +51,7 @@ def semantic_segmentation():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("semantic_segmentation_model.pt")
1 change: 1 addition & 0 deletions flash/image/style_transfer/cli.py
@@ -47,6 +47,7 @@ def style_transfer():
"model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"),
},
finetune=False,
legacy=True,
)

cli.trainer.save_checkpoint("style_transfer_model.pt")
8 changes: 2 additions & 6 deletions flash/pointcloud/detection/cli.py
@@ -20,18 +20,14 @@


def from_kitti(
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> PointCloudObjectDetectorData:
"""Downloads and loads the KITTI data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
return PointCloudObjectDetectorData.from_folders(
train_folder="data/KITTI_Tiny/Kitti/train",
val_folder="data/KITTI_Tiny/Kitti/val",
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


8 changes: 2 additions & 6 deletions flash/pointcloud/segmentation/cli.py
@@ -20,18 +20,14 @@


def from_kitti(
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> PointCloudSegmentationData:
"""Downloads and loads the semantic KITTI data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
return PointCloudSegmentationData.from_folders(
train_folder="data/SemanticKittiTiny/train",
val_folder="data/SemanticKittiTiny/val",
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


11 changes: 4 additions & 7 deletions flash/tabular/classification/cli.py
@@ -21,9 +21,8 @@


def from_titanic(
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
val_split: float = 0.1,
**data_module_kwargs,
) -> TabularClassificationData:
"""Downloads and loads the Titanic data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "./data")
@@ -32,10 +31,8 @@ def from_titanic(
"Fare",
target_fields="Survived",
train_file="data/titanic/titanic.csv",
val_split=0.1,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
val_split=val_split,
**data_module_kwargs,
)


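Beyond the kwargs pass-through, this hunk promotes val_split from a hard-coded 0.1 inside the loader call to a typed parameter with the same default, so it can be overridden per run. A minimal sketch, assuming the underlying loader accepts batch_size via **data_module_kwargs:

    from flash.tabular.classification.cli import from_titanic

    # val_split is now a declared parameter; batch_size flows through **data_module_kwargs.
    datamodule = from_titanic(val_split=0.2, batch_size=8)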
1 change: 1 addition & 0 deletions flash/tabular/forecasting/cli.py
@@ -72,6 +72,7 @@ def tabular_forecasting():
},
finetune=False,
datamodule_attributes={"parameters"},
legacy=True,
)

cli.trainer.save_checkpoint("tabular_forecasting_model.pt")
1 change: 1 addition & 0 deletions flash/text/classification/cli.py
@@ -71,6 +71,7 @@ def text_classification():
"trainer.max_epochs": 3,
},
datamodule_attributes={"num_classes", "multi_label", "backbone"},
legacy=True,
)

cli.trainer.save_checkpoint("text_classification_model.pt")
1 change: 1 addition & 0 deletions flash/text/question_answering/cli.py
@@ -48,6 +48,7 @@ def question_answering():
"trainer.max_epochs": 3,
"model.backbone": "distilbert-base-uncased",
},
legacy=True,
)

cli.trainer.save_checkpoint("question_answering_model.pt")
1 change: 1 addition & 0 deletions flash/text/seq2seq/summarization/cli.py
@@ -49,6 +49,7 @@ def summarization():
"trainer.max_epochs": 3,
"model.backbone": "sshleifer/distilbart-xsum-1-1",
},
legacy=True,
)

cli.trainer.save_checkpoint("summarization_model_xsum.pt")
1 change: 1 addition & 0 deletions flash/text/seq2seq/translation/cli.py
@@ -49,6 +49,7 @@ def translation():
"trainer.max_epochs": 3,
"model.backbone": "Helsinki-NLP/opus-mt-en-ro",
},
legacy=True,
)

cli.trainer.save_checkpoint("translation_model_en_ro.pt")
8 changes: 2 additions & 6 deletions flash/video/classification/cli.py
@@ -24,9 +24,7 @@ def from_kinetics(
clip_sampler: str = "uniform",
clip_duration: int = 1,
decode_audio: bool = False,
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> VideoClassificationData:
"""Downloads and loads the Kinetics data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip", "./data")
@@ -36,9 +34,7 @@ def from_kinetics(
clip_sampler=clip_sampler,
clip_duration=clip_duration,
decode_audio=decode_audio,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


11 changes: 11 additions & 0 deletions tests/tabular/classification/test_model.py
@@ -19,6 +19,7 @@
import torch
from pytorch_lightning import Trainer

from flash.__main__ import main
from flash.core.data.io.input import DataKeys
from flash.core.utilities.imports import _TABULAR_AVAILABLE
from flash.tabular.classification.data import TabularClassificationData
@@ -117,3 +118,13 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[tabular]'")):
TabularClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(not _TABULAR_TESTING, reason="tabular libraries aren't installed.")
def test_cli():
cli_args = ["flash", "tabular_classification", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
try:
main()
except SystemExit:
pass
