Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Feb 14, 2022
1 parent b570bef commit cd3a563
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
5 changes: 4 additions & 1 deletion flash/image/classification/integrations/baal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def train_dataloader(self) -> "DataLoader":
if self.has_labelled_data and self.val_split:
self.val_dataloader = self._val_dataloader

return self.labelled.train_dataloader()
if self.has_labelled_data:
return self.labelled.train_dataloader()
# Return a dummy dataloader, will be replaced by the loop
return DataLoader(["dummy"])

def _val_dataloader(self) -> "DataLoader":
self.labelled._val_input = train_val_split(self._dataset, self.val_split)[1]
Expand Down
1 change: 1 addition & 0 deletions flash/image/classification/integrations/baal/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def _reset_dataloader_for_stage(self, running_state: RunningStage):
if is_overridden(dataloader_name, self.trainer.datamodule)
else None
)

if dataloader:
if _PL_GREATER_EQUAL_1_5_0:
setattr(
Expand Down
2 changes: 1 addition & 1 deletion flash/image/face_detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
learning_rate=learning_rate,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
output_transform=FaceDetectionOutputTransform,
output_transform=FaceDetectionOutputTransform(),
)

@staticmethod
Expand Down
3 changes: 1 addition & 2 deletions flash/image/face_detection/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,5 @@ def per_batch_transform(batch: Any) -> Any:
# preds: list of torch.Tensor(N, 5) as x1, y1, x2, y2, score
preds = [preds[preds[:, 5] == batch_idx, :5] for batch_idx in range(len(preds))]
preds = ff.utils.preprocess.adjust_results(preds, scales, paddings)
batch[DataKeys.PREDS] = preds

return batch
return preds

0 comments on commit cd3a563

Please sign in to comment.