From 8cc1bc4cdda96aeda7e757d7fd19aada421ba56c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Feb 2022 12:56:33 +0000 Subject: [PATCH] Fixes --- flash/video/classification/data.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 50007b0d7d..d74c7aaca6 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -181,7 +181,6 @@ def from_files( >>> _ = [os.remove(f"predict_video_{i}.mp4") for i in range(1, 4)] """ ds_kw = dict( - target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -197,6 +196,7 @@ def from_files( train_targets, transform=train_transform, video_sampler=video_sampler, + target_formatter=target_formatter, **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -357,7 +357,6 @@ def from_folders( >>> shutil.rmtree("predict_folder") """ ds_kw = dict( - target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -368,7 +367,12 @@ def from_folders( ) train_input = input_cls( - RunningStage.TRAINING, train_folder, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + train_folder, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -546,7 +550,6 @@ def from_data_frame( >>> del predict_data_frame """ ds_kw = dict( - target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -562,7 +565,12 @@ def from_data_frame( predict_data = (predict_data_frame, input_field, predict_videos_root, predict_resolver) train_input = input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + *train_data, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -754,7 +762,6 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( - target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -770,7 +777,12 @@ def from_csv( predict_data = (predict_file, input_field, predict_videos_root, predict_resolver) train_input = input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TRAINING, + *train_data, + transform=train_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None) @@ -917,7 +929,6 @@ def from_fiftyone( >>> del predict_dataset """ ds_kw = dict( - target_formatter=target_formatter, transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -933,6 +944,7 @@ def from_fiftyone( transform=train_transform, video_sampler=video_sampler, label_field=label_field, + target_formatter=target_formatter, **ds_kw, ) target_formatter = getattr(train_input, "target_formatter", None)