Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jan 20, 2022
1 parent c3a24ab commit f673ec0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
SampleCollection = None

if _TORCHVISION_AVAILABLE:
import torchvision.transforms.functional as FT
from torchvision.transforms.functional import to_tensor


class SemanticSegmentationInput(Input):
Expand Down Expand Up @@ -102,9 +102,9 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
filepath = sample[DataKeys.INPUT]
sample[DataKeys.INPUT] = FT.to_tensor(image_loader(filepath))
sample[DataKeys.INPUT] = to_tensor(image_loader(filepath))
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = (FT.to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0]
sample[DataKeys.TARGET] = (to_tensor(image_loader(sample[DataKeys.TARGET])) * 255).long()[0]
sample = super().load_sample(sample)
sample[DataKeys.METADATA]["filepath"] = filepath
return sample
Expand Down Expand Up @@ -167,6 +167,6 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
class SemanticSegmentationDeserializer(ImageDeserializer):
def serve_load_sample(self, data: str) -> Dict[str, Any]:
result = super().serve_load_sample(data)
result[DataKeys.INPUT] = FT.to_tensor(result[DataKeys.INPUT])
result[DataKeys.INPUT] = to_tensor(result[DataKeys.INPUT])
result[DataKeys.METADATA] = {"size": result[DataKeys.INPUT].shape[-2:]}
return result

0 comments on commit f673ec0

Please sign in to comment.