This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Onboard IceVision inputs to new object #973

Merged
8 commits merged on Nov 17, 2021
96 changes: 87 additions & 9 deletions flash/core/data/utilities/paths.py
@@ -1,5 +1,5 @@
import os
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union


# Copied from torchvision:
@@ -17,24 +17,102 @@ def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
return filename.lower().endswith(extensions)


def isdir(data: Union[str, Tuple[List[str], List[Any]]]) -> bool:
# Copied from torchvision:
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py#L48
def make_dataset(
directory: str,
class_to_idx: Dict[str, int],
extensions: Optional[Tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, int]]:
"""Generates a list of samples of a form (path_to_sample, class).

Args:
directory (str): root dataset directory
class_to_idx (Dict[str, int]): dictionary mapping class name to class index
extensions (optional): A tuple of allowed extensions.
Either extensions or is_valid_file should be passed. Defaults to None.
is_valid_file (optional): A function that takes the path of a file
and checks whether it is a valid file (used to filter out corrupt files).
Both extensions and is_valid_file should not be passed. Defaults to None.

Raises:
ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.

Returns:
List[Tuple[str, int]]: samples of a form (path_to_sample, class)
"""
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:

def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(Tuple[str, ...], extensions))

is_valid_file = cast(Callable[[str], bool], is_valid_file)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
instances.append(item)
return instances
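
For context, a minimal sketch of how make_dataset might be driven; the directory layout and extensions below are hypothetical:

# Hypothetical layout: root/cat/0.jpg, root/cat/1.jpg, root/dog/0.jpg
class_to_idx = {"cat": 0, "dog": 1}
samples = make_dataset("root", class_to_idx, extensions=(".jpg",))
# samples == [("root/cat/0.jpg", 0), ("root/cat/1.jpg", 0), ("root/dog/0.jpg", 1)]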


def isdir(path: Any) -> bool:
try:
return os.path.isdir(data)
return os.path.isdir(path)
except TypeError:
# path is not path-like (e.g. it may be a list of paths)
return False
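
The broadened signature lets callers probe arbitrary inputs without checking the type first; a quick illustration with hypothetical paths:

isdir("data/images")       # True if the directory exists, False otherwise
isdir(["a.jpg", "b.jpg"])  # False: os.path.isdir raises TypeError on a list, which is caught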


def list_valid_files(data: Union[str, List[str]], valid_extensions: Optional[Tuple[str, ...]] = None):
if isdir(data):
data = [os.path.join(data, file) for file in os.listdir(data)]
def find_classes(dir: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset. Ensures that no class is a subdirectory of another.

Args:
dir: Root directory path.

Returns:
(classes, class_to_idx) where classes are relative to ``dir`` and class_to_idx is a dictionary mapping each class name to its index.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
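
A usage sketch for find_classes, again with a hypothetical layout:

# Hypothetical layout: root/ants/, root/bees/
classes, class_to_idx = find_classes("root")
# classes == ["ants", "bees"]
# class_to_idx == {"ants": 0, "bees": 1}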


def list_valid_files(paths: Union[str, List[str]], valid_extensions: Optional[Tuple[str, ...]] = None) -> List[str]:
"""List the files with a valid extension present in: a single file, a list of files, or a directory.

Args:
paths: A single file, a list of files, or a directory.
valid_extensions: The tuple of valid file extensions.

Returns:
The list of files present in ``paths`` that have a valid extension.
"""
if isdir(paths):
paths = [os.path.join(paths, file) for file in os.listdir(paths)]

if not isinstance(data, list):
data = [data]
if not isinstance(paths, list):
paths = [paths]

if valid_extensions is None:
return paths
return list(
filter(
lambda file: has_file_allowed_extension(file, valid_extensions),
data,
paths,
)
)
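
A short sketch of the three accepted input shapes (paths are hypothetical):

# A directory: its contents are listed, then filtered by extension.
list_valid_files("data/images", valid_extensions=(".jpg", ".png"))
# A single file or a list of files also works; with valid_extensions=None
# the input is returned unfiltered (wrapped in a list if needed).
list_valid_files("data/images/cat.jpg", valid_extensions=(".jpg",))  # ["data/images/cat.jpg"]
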
66 changes: 30 additions & 36 deletions flash/core/integrations/icevision/data.py
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

import numpy as np

from flash.core.data.io.input import DataKeys, LabelsState
from flash.core.data.io.input_base import Input
from flash.core.data.utilities.paths import list_valid_files
from flash.core.integrations.icevision.transforms import from_icevision_record
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.image.data import ImagePathsInput
from flash.image.data import image_loader, IMG_EXTENSIONS, NP_EXTENSIONS

if _ICEVISION_AVAILABLE:
from icevision.core.record import BaseRecord
@@ -28,52 +28,44 @@
from icevision.parsers.parser import Parser


class IceVisionPathsInput(ImagePathsInput):
def predict_load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
return super().predict_load_data(data, dataset)
class IceVisionInput(Input):
def load_data(
self,
root: str,
ann_file: Optional[str] = None,
parser: Optional[Type["Parser"]] = None,
) -> List[Dict[str, Any]]:
if inspect.isclass(parser) and issubclass(parser, Parser):
parser = parser(ann_file, root)
elif isinstance(parser, Callable):
parser = parser(root)
else:
raise ValueError("The parser must be a callable or an IceVision Parser type.")
self.num_classes = parser.class_map.num_classes
self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(self.num_classes)]))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DataKeys.INPUT: record} for record in records[0]]

def predict_load_data(
self, paths: Union[str, List[str]], ann_file: Optional[str] = None, parser: Optional[Type["Parser"]] = None
) -> List[Dict[str, Any]]:
if parser is not None:
return self.load_data(paths, ann_file, parser)
paths = list_valid_files(paths, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return [{DataKeys.INPUT: path} for path in paths]

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
record = sample[DataKeys.INPUT].load()
return from_icevision_record(record)

def predict_load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if isinstance(sample[DataKeys.INPUT], BaseRecord):
# load the data via IceVision Base Record
return self.load_sample(sample)
# load the data using numpy
filepath = sample[DataKeys.INPUT]
sample = super().load_sample(sample)
image = np.array(sample[DataKeys.INPUT])
image = np.array(image_loader(filepath))

record = BaseRecord([FilepathRecordComponent()])
record.filepath = filepath
record.set_img(image)
record.add_component(ClassMapRecordComponent(task=tasks.detection))
return from_icevision_record(record)


class IceVisionParserInput(IceVisionPathsInput):
def __init__(self, parser: Optional[Type["Parser"]] = None):
super().__init__()
self.parser = parser

def load_data(self, data: Tuple[str, str], dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
if self.parser is not None:
if inspect.isclass(self.parser) and issubclass(self.parser, Parser):
root, ann_file = data
parser = self.parser(ann_file, root)
elif isinstance(self.parser, Callable):
parser = self.parser(data)
else:
raise ValueError("The parser must be a callable or an IceVision Parser type.")
dataset.num_classes = parser.class_map.num_classes
self.set_state(LabelsState([parser.class_map.get_by_id(i) for i in range(dataset.num_classes)]))
records = parser.parse(data_splitter=SingleSplitSplitter())
return [{DataKeys.INPUT: record} for record in records[0]]
raise ValueError("The parser argument must be provided.")

def predict_load_data(self, data: Any, dataset: Optional[Any] = None) -> Sequence[Dict[str, Any]]:
result = super().predict_load_data(data, dataset)
if len(result) == 0:
result = self.load_data(data, dataset)
return result
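
For reference, a hedged sketch of how the new IceVisionInput hooks are expected to be driven. The COCO paths are hypothetical, data_input stands for an already-constructed IceVisionInput (construction follows flash.core.data.io.input_base.Input and is elided), and COCOBBoxParser stands in for any IceVision Parser subclass:

from icevision.parsers import COCOBBoxParser

train_samples = data_input.load_data(
    root="data/coco128/images/train2017/",
    ann_file="data/coco128/annotations/instances_train2017.json",
    parser=COCOBBoxParser,  # instantiated internally as parser(ann_file, root)
)

# For prediction a folder (or list) of images is enough; files are filtered
# against IMG_EXTENSIONS + NP_EXTENSIONS by list_valid_files.
predict_samples = data_input.predict_load_data("data/coco128/images/train2017/")
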
12 changes: 5 additions & 7 deletions flash/image/detection/cli.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
@@ -21,19 +22,17 @@

def from_coco_128(
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
image_size: Tuple[int, int] = (128, 128),
**data_module_kwargs,
) -> ObjectDetectionData:
"""Downloads and loads the COCO 128 data set."""
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
return ObjectDetectionData.from_coco(
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
image_size=image_size,
**data_module_kwargs,
)
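
A hypothetical call showing how extra keyword arguments now flow through **data_module_kwargs to ObjectDetectionData.from_coco (batch_size here is an assumption about the kwargs it accepts):

datamodule = from_coco_128(val_split=0.2, image_size=(256, 256), batch_size=8)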


@@ -46,7 +45,6 @@ def object_detection():
default_arguments={
"trainer.max_epochs": 3,
},
legacy=True,
)

cli.trainer.save_checkpoint("object_detection_model.pt")