Skip to content

Commit

Permalink
Datamodule cleanup (#657)
Browse files Browse the repository at this point in the history
* Cleaning up preprocessing methods across DataModules

* Decoupled deleting the bbox from the other transforms in the GeoDataset DataModules

* Cleaning up how channel standardization is done

* Changing default conf for ETCI2021 and fixing So2Sat

* Forgot to update the indices.

* Change to use Normalize

* Remove default augs
  • Loading branch information
calebrob6 committed Jul 9, 2022
1 parent d137589 commit db3f183
Show file tree
Hide file tree
Showing 18 changed files with 149 additions and 153 deletions.
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
encoder_weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
in_channels: 6
num_classes: 2
ignore_index: 0
datamodule:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
in_channels: 6
num_classes: 2
ignore_index: 0
datamodule:
Expand Down
41 changes: 27 additions & 14 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,34 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["mask"] = sample["mask"].squeeze()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0

if "mask" in sample:
sample["mask"] = sample["mask"].squeeze()
if self.use_prior_labels:
sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0)
sample["mask"] = F.normalize(
sample["mask"] + self.prior_smoothing_constant, p=1, dim=0
)
else:
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["mask"] = sample["mask"].long()

if self.use_prior_labels:
sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0)
sample["mask"] = F.normalize(
sample["mask"] + self.prior_smoothing_constant, p=1, dim=0
)
else:
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["mask"] = sample["mask"].long()
return sample

sample["image"] = sample["image"].float()
def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Removes the bounding box property from a sample.
del sample["bbox"]
Args:
sample: dictionary with geographic metadata
Returns
sample without the bbox property
"""
del sample["bbox"]
return sample

def nodata_check(
Expand Down Expand Up @@ -240,19 +250,22 @@ def setup(self, stage: Optional[str] = None) -> None:
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
self.remove_bbox,
]
)
val_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
self.remove_bbox,
]
)
test_transforms = Compose(
[
self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
self.preprocess,
self.remove_bbox,
]
)

Expand Down
12 changes: 7 additions & 5 deletions torchgeo/datamodules/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.batch_size = batch_size
self.num_workers = num_workers

def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
Expand All @@ -51,8 +51,10 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0, 1]
sample["label"] = sample["label"].float()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0 # scale to [0, 1]
if "label" in sample:
sample["label"] = sample["label"].float()
return sample

def prepare_data(self) -> None:
Expand All @@ -73,10 +75,10 @@ def setup(self, stage: Optional[str] = None) -> None:
stage: stage to set up
"""
train_val_dataset = COWCCounting(
self.root_dir, split="train", transforms=self.custom_transform
self.root_dir, split="train", transforms=self.preprocess
)
self.test_dataset = COWCCounting(
self.root_dir, split="test", transforms=self.custom_transform
self.root_dir, split="test", transforms=self.preprocess
)
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
Expand Down
21 changes: 8 additions & 13 deletions torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.num_workers = num_workers
self.api_key = api_key

def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.
Args:
Expand All @@ -60,11 +60,13 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["label"] = torch.as_tensor(sample["label"]).float()
) # convert from grayscale to 3 channel
if "label" in sample:
sample["label"] = torch.as_tensor(sample["label"]).float()

return sample

Expand All @@ -77,7 +79,6 @@ def prepare_data(self) -> None:
TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=self.api_key is not None,
api_key=self.api_key,
)
Expand All @@ -99,17 +100,11 @@ def setup(self, stage: Optional[str] = None) -> None:
stage: stage to set up
"""
self.all_train_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=False,
self.root_dir, split="train", transforms=self.preprocess, download=False
)

self.all_test_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="test",
transforms=self.custom_transform,
download=False,
self.root_dir, split="test", transforms=self.preprocess, download=False
)

storm_ids = []
Expand Down
18 changes: 9 additions & 9 deletions torchgeo/datamodules/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class ETCI2021DataModule(pl.LightningDataModule):
"""

band_means = torch.tensor(
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701]
)

band_stds = torch.tensor(
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622]
)

def __init__(
Expand Down Expand Up @@ -67,15 +67,15 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
image = sample["image"]
water_mask = sample["mask"][0].unsqueeze(0)
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()

sample["image"] = torch.cat([image, water_mask], dim=0).float()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
sample["mask"] = flood_mask

if "mask" in sample:
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()
sample["mask"] = flood_mask

return sample

def prepare_data(self) -> None:
Expand Down
5 changes: 1 addition & 4 deletions torchgeo/datamodules/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import FAIR1M
from .utils import dataset_split
Expand Down Expand Up @@ -85,9 +84,7 @@ def setup(self, stage: Optional[str] = None) -> None:
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])

dataset = FAIR1M(self.root_dir, transforms=transforms)
dataset = FAIR1M(self.root_dir, transforms=self.preprocess)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
Expand Down
27 changes: 15 additions & 12 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
from ..samplers.utils import _to_tuple
from .utils import dataset_split

DEFAULT_AUGS = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)


def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Flatten wrapper."""
Expand Down Expand Up @@ -55,7 +49,6 @@ def __init__(
test_split_pct: float = 0.1,
patch_size: Union[int, Tuple[int, int]] = 512,
num_patches_per_tile: int = 32,
augmentations: K.AugmentationSequential = DEFAULT_AUGS,
predict_on: str = "test",
) -> None:
"""Initialize a LightningDataModule for InriaAerialImageLabeling based DataLoaders.
Expand All @@ -70,7 +63,6 @@ def __init__(
test_split_pct: What percentage of the dataset to use as a test set
patch_size: Size of random patch from image and mask (height, width)
num_patches_per_tile: Number of random patches per sample
augmentations: Default augmentations applied
predict_on: Directory/Dataset of images to run inference on
"""
super().__init__() # type: ignore[no-untyped-call]
Expand All @@ -81,7 +73,11 @@ def __init__(
self.test_split_pct = test_split_pct
self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
self.num_patches_per_tile = num_patches_per_tile
self.augmentations = augmentations
self.augmentations = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)
self.predict_on = predict_on
self.random_crop = K.AugmentationSequential(
K.RandomCrop(self.patch_size, p=1.0, keepdim=False),
Expand All @@ -107,9 +103,16 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
# RGB is int32 so divide by 255
sample["image"] = sample["image"] / 255.0
"""Transform a single sample from the Dataset.
Args:
sample: input image dictionary
Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0)

if "mask" in sample:
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0

sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].long() + 1
sample["image"] /= 255.0

if "mask" in sample:
sample["mask"] = sample["mask"].long() + 1

return sample

Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datamodules/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
sample["image"] /= 255.0

return sample

Expand Down
27 changes: 20 additions & 7 deletions torchgeo/datamodules/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples
from ..samplers.batch import RandomBatchGeoSampler
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(
self.num_workers = num_workers
self.patch_size = patch_size

def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.
Args:
Expand All @@ -61,10 +62,8 @@ def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed NAIP data
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()

del sample["bbox"]
sample["image"] /= 255.0

return sample

Expand All @@ -79,8 +78,18 @@ def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""
sample["mask"] = sample["mask"].long()[0]

del sample["bbox"]
return sample

def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Remove the bounding box property from a sample.

    The sample is modified in place and the same dictionary is returned so
    that this method can be chained inside a ``Compose`` pipeline.

    Args:
        sample: dictionary with geographic metadata

    Returns:
        sample without the bbox property
    """
    # Tolerate samples that carry no "bbox" key (e.g. if this transform is
    # applied twice) instead of raising KeyError; this mirrors the
    # ``if "mask" in sample`` guards used by the preprocessing methods.
    sample.pop("bbox", None)
    return sample

def prepare_data(self) -> None:
Expand All @@ -100,14 +109,18 @@ def setup(self, stage: Optional[str] = None) -> None:
"""
# TODO: these transforms will be applied independently, this won't work if we
# add things like random horizontal flip

naip_transforms = Compose([self.preprocess, self.remove_bbox])
chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox])

chesapeake = Chesapeake13(
self.chesapeake_root_dir, transforms=self.chesapeake_transform
self.chesapeake_root_dir, transforms=chesapeak_transforms
)
naip = NAIP(
self.naip_root_dir,
chesapeake.crs,
chesapeake.res,
transforms=self.naip_transform,
transforms=naip_transforms,
)
self.dataset = chesapeake & naip

Expand Down
Loading

0 comments on commit db3f183

Please sign in to comment.