From db3f1832095875704d575b2a60b58203ade50c20 Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sat, 9 Jul 2022 13:59:43 -0700 Subject: [PATCH] Datamodule cleanup (#657) * Cleaning up preprocessing methods across DataModules * Decoupled deleting the bbox with the other transforms in the GeoDataset DataModules * Cleaning up how channel standardization is done * Changing default conf for ETCI2021 and fixing So2Sat * Forgot to update the indices. * Change to use Normalize * Remove default augs --- conf/etci2021.yaml | 2 +- tests/conf/etci2021.yaml | 2 +- torchgeo/datamodules/chesapeake.py | 41 +++++++----- torchgeo/datamodules/cowc.py | 12 ++-- torchgeo/datamodules/cyclone.py | 21 +++---- torchgeo/datamodules/etci2021.py | 18 +++--- torchgeo/datamodules/fair1m.py | 5 +- torchgeo/datamodules/inria.py | 27 ++++---- torchgeo/datamodules/landcoverai.py | 7 ++- torchgeo/datamodules/loveda.py | 3 +- torchgeo/datamodules/naip.py | 27 +++++--- torchgeo/datamodules/nasa_marine_debris.py | 5 +- torchgeo/datamodules/oscd.py | 11 ++-- torchgeo/datamodules/resisc45.py | 1 - torchgeo/datamodules/sen12ms.py | 11 ++-- torchgeo/datamodules/so2sat.py | 73 ++++++++-------------- torchgeo/datamodules/ucmerced.py | 10 +-- torchgeo/datamodules/usavars.py | 26 +++++--- 18 files changed, 149 insertions(+), 153 deletions(-) diff --git a/conf/etci2021.yaml b/conf/etci2021.yaml index 3db44fb8fc3..6fdcf92778e 100644 --- a/conf/etci2021.yaml +++ b/conf/etci2021.yaml @@ -7,7 +7,7 @@ experiment: encoder_weights: "imagenet" learning_rate: 1e-3 learning_rate_schedule_patience: 6 - in_channels: 7 + in_channels: 6 num_classes: 2 ignore_index: 0 datamodule: diff --git a/tests/conf/etci2021.yaml b/tests/conf/etci2021.yaml index 722cfb35402..54e3dc2b629 100644 --- a/tests/conf/etci2021.yaml +++ b/tests/conf/etci2021.yaml @@ -7,7 +7,7 @@ experiment: encoder_weights: null learning_rate: 1e-3 learning_rate_schedule_patience: 6 - in_channels: 7 + in_channels: 6 num_classes: 2 ignore_index: 0 datamodule: diff --git a/torchgeo/datamodules/chesapeake.py b/torchgeo/datamodules/chesapeake.py index 639c8a2f645..74eded7e09d 100644 --- a/torchgeo/datamodules/chesapeake.py +++ b/torchgeo/datamodules/chesapeake.py @@ -169,24 +169,34 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - sample["image"] = sample["image"] / 255.0 - sample["mask"] = sample["mask"].squeeze() + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + + if "mask" in sample: + sample["mask"] = sample["mask"].squeeze() + if self.use_prior_labels: + sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0) + sample["mask"] = F.normalize( + sample["mask"] + self.prior_smoothing_constant, p=1, dim=0 + ) + else: + if self.class_set == 5: + sample["mask"][sample["mask"] == 5] = 4 + sample["mask"][sample["mask"] == 6] = 4 + sample["mask"] = sample["mask"].long() - if self.use_prior_labels: - sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0) - sample["mask"] = F.normalize( - sample["mask"] + self.prior_smoothing_constant, p=1, dim=0 - ) - else: - if self.class_set == 5: - sample["mask"][sample["mask"] == 5] = 4 - sample["mask"][sample["mask"] == 6] = 4 - sample["mask"] = sample["mask"].long() + return sample - sample["image"] = sample["image"].float() + def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Removes the bounding box property from a sample. - del sample["bbox"] + Args: + sample: dictionary with geographic metadata + Returns + sample without the bbox property + """ + del sample["bbox"] return sample def nodata_check( @@ -240,6 +250,7 @@ def setup(self, stage: Optional[str] = None) -> None: self.center_crop(self.patch_size), self.nodata_check(self.patch_size), self.preprocess, + self.remove_bbox, ] ) val_transforms = Compose( @@ -247,12 +258,14 @@ def setup(self, stage: Optional[str] = None) -> None: self.center_crop(self.patch_size), self.nodata_check(self.patch_size), self.preprocess, + self.remove_bbox, ] ) test_transforms = Compose( [ self.pad_to(self.original_patch_size, image_value=0, mask_value=0), self.preprocess, + self.remove_bbox, ] ) diff --git a/torchgeo/datamodules/cowc.py b/torchgeo/datamodules/cowc.py index 44743cacec0..736c13a35b7 100644 --- a/torchgeo/datamodules/cowc.py +++ b/torchgeo/datamodules/cowc.py @@ -42,7 +42,7 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. Args: @@ -51,8 +51,10 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - sample["image"] = sample["image"] / 255.0 # scale to [0, 1] - sample["label"] = sample["label"].float() + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 # scale to [0, 1] + if "label" in sample: + sample["label"] = sample["label"].float() return sample def prepare_data(self) -> None: @@ -73,10 +75,10 @@ def setup(self, stage: Optional[str] = None) -> None: stage: stage to set up """ train_val_dataset = COWCCounting( - self.root_dir, split="train", transforms=self.custom_transform + self.root_dir, split="train", transforms=self.preprocess ) self.test_dataset = COWCCounting( - self.root_dir, split="test", transforms=self.custom_transform + self.root_dir, split="test", transforms=self.preprocess ) self.train_dataset, self.val_dataset = random_split( train_val_dataset, diff --git a/torchgeo/datamodules/cyclone.py b/torchgeo/datamodules/cyclone.py index 85e4e494111..9a9724234b6 100644 --- a/torchgeo/datamodules/cyclone.py +++ b/torchgeo/datamodules/cyclone.py @@ -51,7 +51,7 @@ def __init__( self.num_workers = num_workers self.api_key = api_key - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. Args: @@ -60,11 +60,13 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - sample["image"] = sample["image"] / 255.0 # scale to [0,1] + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 sample["image"] = ( sample["image"].unsqueeze(0).repeat(3, 1, 1) - ) # convert to 3 channel - sample["label"] = torch.as_tensor(sample["label"]).float() + ) # convert from grayscale to 3 channel + if "label" in sample: + sample["label"] = torch.as_tensor(sample["label"]).float() return sample @@ -77,7 +79,6 @@ def prepare_data(self) -> None: TropicalCycloneWindEstimation( self.root_dir, split="train", - transforms=self.custom_transform, download=self.api_key is not None, api_key=self.api_key, ) @@ -99,17 +100,11 @@ def setup(self, stage: Optional[str] = None) -> None: stage: stage to set up """ self.all_train_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="train", - transforms=self.custom_transform, - download=False, + self.root_dir, split="train", transforms=self.preprocess, download=False ) self.all_test_dataset = TropicalCycloneWindEstimation( - self.root_dir, - split="test", - transforms=self.custom_transform, - download=False, + self.root_dir, split="test", transforms=self.preprocess, download=False ) storm_ids = [] diff --git a/torchgeo/datamodules/etci2021.py b/torchgeo/datamodules/etci2021.py index 7efee68597e..0477c1f4012 100644 --- a/torchgeo/datamodules/etci2021.py +++ b/torchgeo/datamodules/etci2021.py @@ -25,11 +25,11 @@ class ETCI2021DataModule(pl.LightningDataModule): """ band_means = torch.tensor( - [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0] + [0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701] ) band_stds = torch.tensor( - [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1] + [0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622] ) def __init__( @@ -67,15 +67,15 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - image = sample["image"] - water_mask = sample["mask"][0].unsqueeze(0) - flood_mask = sample["mask"][1] - flood_mask = (flood_mask > 0).long() - - sample["image"] = torch.cat([image, water_mask], dim=0).float() + sample["image"] = sample["image"].float() sample["image"] /= 255.0 sample["image"] = self.norm(sample["image"]) - sample["mask"] = flood_mask + + if "mask" in sample: + flood_mask = sample["mask"][1] + flood_mask = (flood_mask > 0).long() + sample["mask"] = flood_mask + return sample def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/fair1m.py b/torchgeo/datamodules/fair1m.py index a22973fe98b..a8459c7ba92 100644 --- a/torchgeo/datamodules/fair1m.py +++ b/torchgeo/datamodules/fair1m.py @@ -9,7 +9,6 @@ import torch from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import FAIR1M from .utils import dataset_split @@ -85,9 +84,7 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - dataset = FAIR1M(self.root_dir, transforms=transforms) + dataset = FAIR1M(self.root_dir, transforms=self.preprocess) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) diff --git a/torchgeo/datamodules/inria.py b/torchgeo/datamodules/inria.py index 88f502c59bd..9d5c177f39a 100644 --- a/torchgeo/datamodules/inria.py +++ b/torchgeo/datamodules/inria.py @@ -18,12 +18,6 @@ from ..samplers.utils import _to_tuple from .utils import dataset_split -DEFAULT_AUGS = K.AugmentationSequential( - K.RandomHorizontalFlip(p=0.5), - K.RandomVerticalFlip(p=0.5), - data_keys=["input", "mask"], -) - def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]: """Flatten wrapper.""" @@ -55,7 +49,6 @@ def __init__( test_split_pct: float = 0.1, patch_size: Union[int, Tuple[int, int]] = 512, num_patches_per_tile: int = 32, - augmentations: K.AugmentationSequential = DEFAULT_AUGS, predict_on: str = "test", ) -> None: """Initialize a LightningDataModule for InriaAerialImageLabeling based DataLoaders. @@ -70,7 +63,6 @@ def __init__( test_split_pct: What percentage of the dataset to use as a test set patch_size: Size of random patch from image and mask (height, width) num_patches_per_tile: Number of random patches per sample - augmentations: Default augmentations applied predict_on: Directory/Dataset of images to run inference on """ super().__init__() # type: ignore[no-untyped-call] @@ -81,7 +73,11 @@ def __init__( self.test_split_pct = test_split_pct self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size)) self.num_patches_per_tile = num_patches_per_tile - self.augmentations = augmentations + self.augmentations = K.AugmentationSequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomVerticalFlip(p=0.5), + data_keys=["input", "mask"], + ) self.predict_on = predict_on self.random_crop = K.AugmentationSequential( K.RandomCrop(self.patch_size, p=1.0, keepdim=False), @@ -107,9 +103,16 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: return sample def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: - """Transform a single sample from the Dataset.""" - # RGB is int32 so divide by 255 - sample["image"] = sample["image"] / 255.0 + """Transform a single sample from the Dataset. + + Args: + sample: input image dictionary + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0) if "mask" in sample: diff --git a/torchgeo/datamodules/landcoverai.py b/torchgeo/datamodules/landcoverai.py index dcece944efe..b74f649fedf 100644 --- a/torchgeo/datamodules/landcoverai.py +++ b/torchgeo/datamodules/landcoverai.py @@ -92,10 +92,11 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - sample["image"] = sample["image"] / 255.0 - sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"].long() + 1 + sample["image"] /= 255.0 + + if "mask" in sample: + sample["mask"] = sample["mask"].long() + 1 return sample diff --git a/torchgeo/datamodules/loveda.py b/torchgeo/datamodules/loveda.py index 70e94c970f0..2bfddd953c8 100644 --- a/torchgeo/datamodules/loveda.py +++ b/torchgeo/datamodules/loveda.py @@ -54,7 +54,8 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - sample["image"] = sample["image"] / 255.0 + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 return sample diff --git a/torchgeo/datamodules/naip.py b/torchgeo/datamodules/naip.py index 928674dc5bd..acd8cdcbee7 100644 --- a/torchgeo/datamodules/naip.py +++ b/torchgeo/datamodules/naip.py @@ -7,6 +7,7 @@ import pytorch_lightning as pl from torch.utils.data import DataLoader +from torchvision.transforms import Compose from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples from ..samplers.batch import RandomBatchGeoSampler @@ -52,7 +53,7 @@ def __init__( self.num_workers = num_workers self.patch_size = patch_size - def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the NAIP Dataset. Args: @@ -61,10 +62,8 @@ def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed NAIP data """ - sample["image"] = sample["image"] / 255.0 sample["image"] = sample["image"].float() - - del sample["bbox"] + sample["image"] /= 255.0 return sample @@ -79,8 +78,18 @@ def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: """ sample["mask"] = sample["mask"].long()[0] - del sample["bbox"] + return sample + + def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Removes the bounding box property from a sample. + Args: + sample: dictionary with geographic metadata + + Returns + sample without the bbox property + """ + del sample["bbox"] return sample def prepare_data(self) -> None: @@ -100,14 +109,18 @@ def setup(self, stage: Optional[str] = None) -> None: """ # TODO: these transforms will be applied independently, this won't work if we # add things like random horizontal flip + + naip_transforms = Compose([self.preprocess, self.remove_bbox]) + chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox]) + chesapeake = Chesapeake13( - self.chesapeake_root_dir, transforms=self.chesapeake_transform + self.chesapeake_root_dir, transforms=chesapeak_transforms ) naip = NAIP( self.naip_root_dir, chesapeake.crs, chesapeake.res, - transforms=self.naip_transform, + transforms=naip_transforms, ) self.dataset = chesapeake & naip diff --git a/torchgeo/datamodules/nasa_marine_debris.py b/torchgeo/datamodules/nasa_marine_debris.py index e7d95921ae6..989deeb815f 100644 --- a/torchgeo/datamodules/nasa_marine_debris.py +++ b/torchgeo/datamodules/nasa_marine_debris.py @@ -9,7 +9,6 @@ import torch from torch import Tensor from torch.utils.data import DataLoader -from torchvision.transforms import Compose from ..datasets import NASAMarineDebris from .utils import dataset_split @@ -93,9 +92,7 @@ def setup(self, stage: Optional[str] = None) -> None: Args: stage: stage to set up """ - transforms = Compose([self.preprocess]) - - dataset = NASAMarineDebris(self.root_dir, transforms=transforms) + dataset = NASAMarineDebris(self.root_dir, transforms=self.preprocess) self.train_dataset, self.val_dataset, self.test_dataset = dataset_split( dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct ) diff --git a/torchgeo/datamodules/oscd.py b/torchgeo/datamodules/oscd.py index 441d680d32d..d11649a8a91 100644 --- a/torchgeo/datamodules/oscd.py +++ b/torchgeo/datamodules/oscd.py @@ -97,22 +97,19 @@ def __init__( self.num_patches_per_tile = num_patches_per_tile if bands == "rgb": - self.band_means = self.band_means[[3, 2, 1], None, None] - self.band_stds = self.band_stds[[3, 2, 1], None, None] - else: - self.band_means = self.band_means[:, None, None] - self.band_stds = self.band_stds[:, None, None] + self.band_means = self.band_means[[3, 2, 1]] + self.band_stds = self.band_stds[[3, 2, 1]] - self.norm = Normalize(self.band_means, self.band_stds) self.rcrop = K.AugmentationSequential( K.RandomCrop(patch_size), data_keys=["input", "mask"], same_on_batch=True ) self.padto = K.PadTo(pad_size) + self.norm = Normalize(self.band_means, self.band_stds) + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset.""" sample["image"] = sample["image"].float() - sample["mask"] = sample["mask"] sample["image"] = self.norm(sample["image"]) sample["image"] = torch.flatten(sample["image"], 0, 1) return sample diff --git a/torchgeo/datamodules/resisc45.py b/torchgeo/datamodules/resisc45.py index cdb1d9c324c..2c111b0431e 100644 --- a/torchgeo/datamodules/resisc45.py +++ b/torchgeo/datamodules/resisc45.py @@ -26,7 +26,6 @@ class RESISC45DataModule(pl.LightningDataModule): """ band_means = torch.tensor([0.36820969, 0.38083247, 0.34341029]) - band_stds = torch.tensor([0.20339924, 0.18524736, 0.18455448]) def __init__( diff --git a/torchgeo/datamodules/sen12ms.py b/torchgeo/datamodules/sen12ms.py index 7482a207047..e248d718e74 100644 --- a/torchgeo/datamodules/sen12ms.py +++ b/torchgeo/datamodules/sen12ms.py @@ -82,7 +82,7 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers - def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. Args: @@ -101,8 +101,9 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: else: sample["image"][:] = sample["image"][:].clamp(0, 10000) / 10000 - sample["mask"] = sample["mask"][0, :, :].long() - sample["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, sample["mask"]) + if "mask" in sample: + sample["mask"] = sample["mask"][0, :, :].long() + sample["mask"] = torch.take(self.DFC2020_CLASS_MAPPING, sample["mask"]) return sample @@ -124,7 +125,7 @@ def setup(self, stage: Optional[str] = None) -> None: self.root_dir, split="train", bands=self.band_indices, - transforms=self.custom_transform, + transforms=self.preprocess, checksum=False, ) @@ -132,7 +133,7 @@ def setup(self, stage: Optional[str] = None) -> None: self.root_dir, split="test", bands=self.band_indices, - transforms=self.custom_transform, + transforms=self.preprocess, checksum=False, ) diff --git a/torchgeo/datamodules/so2sat.py b/torchgeo/datamodules/so2sat.py index e6fd9c4250d..968e81828ff 100644 --- a/torchgeo/datamodules/so2sat.py +++ b/torchgeo/datamodules/so2sat.py @@ -8,7 +8,7 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader -from torchvision.transforms import Compose +from torchvision.transforms import Compose, Normalize from ..datasets import So2Sat @@ -25,14 +25,6 @@ class So2SatDataModule(pl.LightningDataModule): band_means = torch.tensor( [ - -3.591224256609313e-05, - -7.658561276843396e-06, - 5.9373857475971184e-05, - 2.5166231537121083e-05, - 0.04420110659759328, - 0.25761027084996196, - 0.0007556743372573258, - 0.0013503466830024448, 0.12375696117681859, 0.1092774636368323, 0.1010855203267882, @@ -44,18 +36,10 @@ class So2SatDataModule(pl.LightningDataModule): 0.15428468872076637, 0.10905050699570007, ] - ).reshape(18, 1, 1) + ) band_stds = torch.tensor( [ - 0.17555201137417686, - 0.17556463274968204, - 0.45998793417834255, - 0.455988755730148, - 2.8559909213125763, - 8.324800606439833, - 2.4498757382563103, - 1.4647352984509094, 0.03958795985905458, 0.047778262752410296, 0.06636616706371974, @@ -67,29 +51,10 @@ class So2SatDataModule(pl.LightningDataModule): 0.09991773043519253, 0.08780632509122865, ] - ).reshape(18, 1, 1) - - # this reorders the bands to put S2 RGB first, then remainder of S2, then S1 - reindex_to_rgb_first = [ - 10, - 9, - 8, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - # 0, - # 1, - # 2, - # 3, - # 4, - # 5, - # 6, - # 7, - ] + ) + + # this reorders the bands to put S2 RGB first, then remainder of S2 + reindex_to_rgb_first = [2, 1, 0, 3, 4, 5, 6, 7, 8, 9] def __init__( self, @@ -117,6 +82,8 @@ def __init__( self.bands = bands self.unsupervised_mode = unsupervised_mode + self.norm = Normalize(self.band_means, self.band_stds) + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. @@ -126,8 +93,8 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: Returns: preprocessed sample """ - # sample["image"] = (sample["image"] - self.band_means) / self.band_stds sample["image"] = sample["image"].float() + sample["image"] = self.norm(sample["image"]) sample["image"] = sample["image"][self.reindex_to_rgb_first, :, :] if self.bands == "rgb": @@ -153,32 +120,42 @@ def setup(self, stage: Optional[str] = None) -> None: train_transforms = Compose([self.preprocess]) val_test_transforms = self.preprocess + s2bands = So2Sat.BAND_SETS["s2"] if not self.unsupervised_mode: self.train_dataset = So2Sat( - self.root_dir, split="train", transforms=train_transforms + self.root_dir, split="train", bands=s2bands, transforms=train_transforms ) self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=val_test_transforms + self.root_dir, + split="validation", + bands=s2bands, + transforms=val_test_transforms, ) self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=val_test_transforms + self.root_dir, + split="test", + bands=s2bands, + transforms=val_test_transforms, ) else: temp_train = So2Sat( - self.root_dir, split="train", transforms=train_transforms + self.root_dir, split="train", bands=s2bands, transforms=train_transforms ) self.val_dataset = So2Sat( - self.root_dir, split="validation", transforms=train_transforms + self.root_dir, + split="validation", + bands=s2bands, + transforms=train_transforms, ) self.test_dataset = So2Sat( - self.root_dir, split="test", transforms=train_transforms + self.root_dir, split="test", bands=s2bands, transforms=train_transforms ) self.train_dataset = cast( diff --git a/torchgeo/datamodules/ucmerced.py b/torchgeo/datamodules/ucmerced.py index c352c783a54..f95f75cf454 100644 --- a/torchgeo/datamodules/ucmerced.py +++ b/torchgeo/datamodules/ucmerced.py @@ -7,10 +7,9 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl -import torch import torchvision from torch.utils.data import DataLoader -from torchvision.transforms import Compose, Normalize +from torchvision.transforms import Compose from ..datasets import UCMerced @@ -25,10 +24,6 @@ class UCMercedDataModule(pl.LightningDataModule): Uses random train/val/test splits. """ - band_means = torch.tensor([0, 0, 0]) - - band_stds = torch.tensor([1, 1, 1]) - def __init__( self, root_dir: str, batch_size: int = 64, num_workers: int = 0, **kwargs: Any ) -> None: @@ -44,8 +39,6 @@ def __init__( self.batch_size = batch_size self.num_workers = num_workers - self.norm = Normalize(self.band_means, self.band_stds) - def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: """Transform a single sample from the Dataset. @@ -62,7 +55,6 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: sample["image"] = torchvision.transforms.functional.resize( sample["image"], size=(256, 256) ) - sample["image"] = self.norm(sample["image"]) return sample def prepare_data(self) -> None: diff --git a/torchgeo/datamodules/usavars.py b/torchgeo/datamodules/usavars.py index 7caa4862307..a3cedfd6251 100644 --- a/torchgeo/datamodules/usavars.py +++ b/torchgeo/datamodules/usavars.py @@ -3,10 +3,9 @@ """USAVars datamodule.""" -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Dict, Optional, Sequence import pytorch_lightning as pl -from torch import Tensor from torch.utils.data import DataLoader from ..datasets import USAVars @@ -24,7 +23,6 @@ def __init__( self, root_dir: str, labels: Sequence[str] = USAVars.ALL_LABELS, - transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, batch_size: int = 64, num_workers: int = 0, ) -> None: @@ -33,18 +31,28 @@ def __init__( Args: root_dir: The root argument passed to the USAVars Dataset classes labels: The labels argument passed to the USAVars Dataset classes - transforms: a function/transform that takes input sample and its target as - entry and returns a transformed version batch_size: The batch size to use in all created DataLoaders num_workers: The number of workers to use in all created DataLoaders """ super().__init__() self.root_dir = root_dir self.labels = labels - self.transforms = transforms self.batch_size = batch_size self.num_workers = num_workers + def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]: + """Transform a single sample from the Dataset. + + Args: + sample: dictionary containing image + + Returns: + preprocessed sample + """ + sample["image"] = sample["image"].float() + sample["image"] /= 255.0 + return sample + def prepare_data(self) -> None: """Make sure that the dataset is downloaded. @@ -58,13 +66,13 @@ def setup(self, stage: Optional[str] = None) -> None: This method is called once per GPU per run. """ self.train_dataset = USAVars( - self.root_dir, "train", self.labels, transforms=self.transforms + self.root_dir, "train", self.labels, transforms=self.preprocess ) self.val_dataset = USAVars( - self.root_dir, "val", self.labels, transforms=self.transforms + self.root_dir, "val", self.labels, transforms=self.preprocess ) self.test_dataset = USAVars( - self.root_dir, "test", self.labels, transforms=self.transforms + self.root_dir, "test", self.labels, transforms=self.preprocess ) def train_dataloader(self) -> DataLoader[Any]: