Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Datamodule cleanup #657

Merged
merged 7 commits into from
Jul 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
encoder_weights: "imagenet"
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
in_channels: 6
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
num_classes: 2
ignore_index: 0
datamodule:
Expand Down
2 changes: 1 addition & 1 deletion tests/conf/etci2021.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ experiment:
encoder_weights: null
learning_rate: 1e-3
learning_rate_schedule_patience: 6
in_channels: 7
in_channels: 6
num_classes: 2
ignore_index: 0
datamodule:
Expand Down
41 changes: 27 additions & 14 deletions torchgeo/datamodules/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,24 +169,34 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
sample["mask"] = sample["mask"].squeeze()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0

if "mask" in sample:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
sample["mask"] = sample["mask"].squeeze()
if self.use_prior_labels:
sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0)
sample["mask"] = F.normalize(
sample["mask"] + self.prior_smoothing_constant, p=1, dim=0
)
else:
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["mask"] = sample["mask"].long()

if self.use_prior_labels:
sample["mask"] = F.normalize(sample["mask"].float(), p=1, dim=0)
sample["mask"] = F.normalize(
sample["mask"] + self.prior_smoothing_constant, p=1, dim=0
)
else:
if self.class_set == 5:
sample["mask"][sample["mask"] == 5] = 4
sample["mask"][sample["mask"] == 6] = 4
sample["mask"] = sample["mask"].long()
return sample

sample["image"] = sample["image"].float()
def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the "bbox" key is required for using the samples in the lightning trainer, however we need the "bbox" for predictions so I decoupled this logic from preprocess.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you remember why we had to remove bbox again? I feel like there was an option to keep it if we did something else differently...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm suspicious that it has something to do with collation, but the trainers/test definitely break if you don't remove it.

"""Removes the bounding box property from a sample.

del sample["bbox"]
Args:
sample: dictionary with geographic metadata

Returns
sample without the bbox property
"""
del sample["bbox"]
return sample

def nodata_check(
Expand Down Expand Up @@ -240,19 +250,22 @@ def setup(self, stage: Optional[str] = None) -> None:
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
self.remove_bbox,
]
)
val_transforms = Compose(
[
self.center_crop(self.patch_size),
self.nodata_check(self.patch_size),
self.preprocess,
self.remove_bbox,
]
)
test_transforms = Compose(
[
self.pad_to(self.original_patch_size, image_value=0, mask_value=0),
self.preprocess,
self.remove_bbox,
]
)

Expand Down
12 changes: 7 additions & 5 deletions torchgeo/datamodules/cowc.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
self.batch_size = batch_size
self.num_workers = num_workers

def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.

Args:
Expand All @@ -51,8 +51,10 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0, 1]
sample["label"] = sample["label"].float()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0 # scale to [0, 1]
if "label" in sample:
sample["label"] = sample["label"].float()
return sample

def prepare_data(self) -> None:
Expand All @@ -73,10 +75,10 @@ def setup(self, stage: Optional[str] = None) -> None:
stage: stage to set up
"""
train_val_dataset = COWCCounting(
self.root_dir, split="train", transforms=self.custom_transform
self.root_dir, split="train", transforms=self.preprocess
)
self.test_dataset = COWCCounting(
self.root_dir, split="test", transforms=self.custom_transform
self.root_dir, split="test", transforms=self.preprocess
)
self.train_dataset, self.val_dataset = random_split(
train_val_dataset,
Expand Down
21 changes: 8 additions & 13 deletions torchgeo/datamodules/cyclone.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
self.num_workers = num_workers
self.api_key = api_key

def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset.

Args:
Expand All @@ -60,11 +60,13 @@ def custom_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0 # scale to [0,1]
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = (
sample["image"].unsqueeze(0).repeat(3, 1, 1)
) # convert to 3 channel
sample["label"] = torch.as_tensor(sample["label"]).float()
) # convert from grayscale to 3 channel
if "label" in sample:
sample["label"] = torch.as_tensor(sample["label"]).float()

return sample

Expand All @@ -77,7 +79,6 @@ def prepare_data(self) -> None:
TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=self.api_key is not None,
api_key=self.api_key,
)
Expand All @@ -99,17 +100,11 @@ def setup(self, stage: Optional[str] = None) -> None:
stage: stage to set up
"""
self.all_train_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="train",
transforms=self.custom_transform,
download=False,
self.root_dir, split="train", transforms=self.preprocess, download=False
)

self.all_test_dataset = TropicalCycloneWindEstimation(
self.root_dir,
split="test",
transforms=self.custom_transform,
download=False,
self.root_dir, split="test", transforms=self.preprocess, download=False
)

storm_ids = []
Expand Down
18 changes: 9 additions & 9 deletions torchgeo/datamodules/etci2021.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class ETCI2021DataModule(pl.LightningDataModule):
"""

band_means = torch.tensor(
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701, 0]
[0.52253931, 0.52253931, 0.52253931, 0.61221701, 0.61221701, 0.61221701]
)

band_stds = torch.tensor(
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622, 1]
[0.35221376, 0.35221376, 0.35221376, 0.37364622, 0.37364622, 0.37364622]
)

def __init__(
Expand Down Expand Up @@ -67,15 +67,15 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
image = sample["image"]
water_mask = sample["mask"][0].unsqueeze(0)
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()

sample["image"] = torch.cat([image, water_mask], dim=0).float()
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = self.norm(sample["image"])
sample["mask"] = flood_mask

if "mask" in sample:
flood_mask = sample["mask"][1]
flood_mask = (flood_mask > 0).long()
sample["mask"] = flood_mask

return sample

def prepare_data(self) -> None:
Expand Down
5 changes: 1 addition & 4 deletions torchgeo/datamodules/fair1m.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import FAIR1M
from .utils import dataset_split
Expand Down Expand Up @@ -85,9 +84,7 @@ def setup(self, stage: Optional[str] = None) -> None:
Args:
stage: stage to set up
"""
transforms = Compose([self.preprocess])

dataset = FAIR1M(self.root_dir, transforms=transforms)
dataset = FAIR1M(self.root_dir, transforms=self.preprocess)
self.train_dataset, self.val_dataset, self.test_dataset = dataset_split(
dataset, val_pct=self.val_split_pct, test_pct=self.test_split_pct
)
Expand Down
27 changes: 15 additions & 12 deletions torchgeo/datamodules/inria.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
from ..samplers.utils import _to_tuple
from .utils import dataset_split

DEFAULT_AUGS = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)


def collate_wrapper(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Flatten wrapper."""
Expand Down Expand Up @@ -55,7 +49,6 @@ def __init__(
test_split_pct: float = 0.1,
patch_size: Union[int, Tuple[int, int]] = 512,
num_patches_per_tile: int = 32,
augmentations: K.AugmentationSequential = DEFAULT_AUGS,
predict_on: str = "test",
) -> None:
"""Initialize a LightningDataModule for InriaAerialImageLabeling based DataLoaders.
Expand All @@ -70,7 +63,6 @@ def __init__(
test_split_pct: What percentage of the dataset to use as a test set
patch_size: Size of random patch from image and mask (height, width)
num_patches_per_tile: Number of random patches per sample
augmentations: Default augmentations applied
predict_on: Directory/Dataset of images to run inference on
"""
super().__init__() # type: ignore[no-untyped-call]
Expand All @@ -81,7 +73,11 @@ def __init__(
self.test_split_pct = test_split_pct
self.patch_size = cast(Tuple[int, int], _to_tuple(patch_size))
self.num_patches_per_tile = num_patches_per_tile
self.augmentations = augmentations
self.augmentations = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=["input", "mask"],
)
self.predict_on = predict_on
self.random_crop = K.AugmentationSequential(
K.RandomCrop(self.patch_size, p=1.0, keepdim=False),
Expand All @@ -107,9 +103,16 @@ def patch_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
return sample

def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the Dataset."""
# RGB is int32 so divide by 255
sample["image"] = sample["image"] / 255.0
"""Transform a single sample from the Dataset.

Args:
sample: input image dictionary

Returns:
preprocessed sample
"""
sample["image"] = sample["image"].float()
sample["image"] /= 255.0
sample["image"] = torch.clip(sample["image"], min=0.0, max=1.0)

if "mask" in sample:
Expand Down
7 changes: 4 additions & 3 deletions torchgeo/datamodules/landcoverai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,11 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0

sample["image"] = sample["image"].float()
sample["mask"] = sample["mask"].long() + 1
sample["image"] /= 255.0

if "mask" in sample:
sample["mask"] = sample["mask"].long() + 1

return sample

Expand Down
3 changes: 2 additions & 1 deletion torchgeo/datamodules/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed sample
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()
sample["image"] /= 255.0

return sample

Expand Down
27 changes: 20 additions & 7 deletions torchgeo/datamodules/naip.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from ..datasets import NAIP, BoundingBox, Chesapeake13, stack_samples
from ..samplers.batch import RandomBatchGeoSampler
Expand Down Expand Up @@ -52,7 +53,7 @@ def __init__(
self.num_workers = num_workers
self.patch_size = patch_size

def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Transform a single sample from the NAIP Dataset.

Args:
Expand All @@ -61,10 +62,8 @@ def naip_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Returns:
preprocessed NAIP data
"""
sample["image"] = sample["image"] / 255.0
sample["image"] = sample["image"].float()

del sample["bbox"]
sample["image"] /= 255.0

return sample

Expand All @@ -79,8 +78,18 @@ def chesapeake_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""
sample["mask"] = sample["mask"].long()[0]

del sample["bbox"]
return sample

def remove_bbox(self, sample: Dict[str, Any]) -> Dict[str, Any]:
"""Removes the bounding box property from a sample.

Args:
sample: dictionary with geographic metadata

Returns
sample without the bbox property
"""
del sample["bbox"]
return sample

def prepare_data(self) -> None:
Expand All @@ -100,14 +109,18 @@ def setup(self, stage: Optional[str] = None) -> None:
"""
# TODO: these transforms will be applied independently, this won't work if we
# add things like random horizontal flip

naip_transforms = Compose([self.preprocess, self.remove_bbox])
chesapeak_transforms = Compose([self.chesapeake_transform, self.remove_bbox])

chesapeake = Chesapeake13(
self.chesapeake_root_dir, transforms=self.chesapeake_transform
self.chesapeake_root_dir, transforms=chesapeak_transforms
)
naip = NAIP(
self.naip_root_dir,
chesapeake.crs,
chesapeake.res,
transforms=self.naip_transform,
transforms=naip_transforms,
)
self.dataset = chesapeake & naip

Expand Down
Loading