Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Nov 17, 2021
1 parent 0ef45b2 commit 10af852
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 22 deletions.
9 changes: 2 additions & 7 deletions flash/image/detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,17 @@

def from_coco_128(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
image_size: Tuple[int, int] = (128, 128),
**input_transform_kwargs,
**data_module_kwargs,
) -> ObjectDetectionData:
"""Downloads and loads the COCO 128 data set."""
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
return ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
image_size=image_size,
**input_transform_kwargs,
**data_module_kwargs,
)


Expand All @@ -49,7 +45,6 @@ def object_detection():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("object_detection_model.pt")
Expand Down
13 changes: 5 additions & 8 deletions flash/image/instance_segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
Expand All @@ -27,10 +27,9 @@
@requires(["image", "icedata"])
def from_pets(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
image_size: Tuple[int, int] = (128, 128),
parser: Optional[Callable] = None,
**input_transform_kwargs,
**data_module_kwargs,
) -> InstanceSegmentationData:
"""Downloads and loads the pets data set from icedata."""
data_dir = icedata.pets.load_data()
Expand All @@ -41,10 +40,9 @@ def from_pets(
return InstanceSegmentationData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
image_size=image_size,
parser=parser,
**input_transform_kwargs,
**data_module_kwargs,
)


Expand All @@ -57,7 +55,6 @@ def instance_segmentation():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("instance_segmentation_model.pt")
Expand Down
12 changes: 5 additions & 7 deletions flash/image/keypoint_detection/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

from flash.core.utilities.flash_cli import FlashCLI
from flash.core.utilities.imports import _ICEDATA_AVAILABLE, requires
Expand All @@ -26,10 +26,9 @@
@requires("image")
def from_biwi(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
image_size: Tuple[int, int] = (128, 128),
parser: Optional[Callable] = None,
**input_transform_kwargs,
**data_module_kwargs,
) -> KeypointDetectionData:
"""Downloads and loads the BIWI data set from icedata."""
data_dir = icedata.biwi.load_data()
Expand All @@ -40,10 +39,9 @@ def from_biwi(
return KeypointDetectionData.from_folders(
train_folder=data_dir,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
image_size=image_size,
parser=parser,
**input_transform_kwargs,
**data_module_kwargs,
)


Expand Down

0 comments on commit 10af852

Please sign in to comment.